#include "chat.h"
#include "common.h"
#include "llama-cpp.h"
#include "log.h"

#include "linenoise.cpp/linenoise.h"

#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>

#if defined(_WIN32)
#    define WIN32_LEAN_AND_MEAN
#    ifndef NOMINMAX
#        define NOMINMAX
#    endif
#    include <windows.h>
#    include <io.h>
#else
#    include <sys/file.h>
#    include <sys/ioctl.h>
#    include <unistd.h>
#endif

#if defined(LLAMA_USE_CURL)
#    include <curl/curl.h>
#else
#    include "http.h"
#endif

#include <signal.h>

#include <chrono>
#include <climits>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <iomanip>
#include <iostream>
#include <limits>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
[[noreturn]] static void sigint_handler(int) {
    printf("\n" LOG_COL_DEFAULT);
    exit(0); // not ideal, but it's the only way to guarantee exit in all cases
}
#endif

GGML_ATTRIBUTE_FORMAT(1, 2)
static int printe(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    const int ret = vfprintf(stderr, fmt, args);
    va_end(args);

    return ret;
}

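// Format a std::tm timestamp using an strftime-style format string.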
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
    std::ostringstream oss;
    oss << std::put_time(&tm, fmt);

    return oss.str();
}

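// Command-line options for llama-run. Parses flags such as -c/--context-size, -n/--ngl,
// -t/--threads, --temp, --jinja and --chat-template-file, plus the positional model and prompt arguments.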
class Opt {
  public:
    int init(int argc, const char ** argv) {
        ctx_params           = llama_context_default_params();
        model_params         = llama_model_default_params();
        context_size_default = ctx_params.n_batch;
        n_threads_default    = ctx_params.n_threads;
        ngl_default          = model_params.n_gpu_layers;
        common_params_sampling sampling;
        temperature_default = sampling.temp;

        if (argc < 2) {
            printe("Error: No arguments provided.\n");
            print_help();
            return 1;
        }

        // Parse arguments
        if (parse(argc, argv)) {
            printe("Error: Failed to parse arguments.\n");
            print_help();
            return 1;
        }

        // If help is requested, show help and exit
        if (help) {
            print_help();
            return 2;
        }

        ctx_params.n_batch        = context_size >= 0 ? context_size : context_size_default;
        ctx_params.n_ctx          = ctx_params.n_batch;
        ctx_params.n_threads      = ctx_params.n_threads_batch = n_threads >= 0 ? n_threads : n_threads_default;
        model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
        temperature               = temperature >= 0 ? temperature : temperature_default;

        return 0; // Success
    }

    llama_context_params ctx_params;
    llama_model_params   model_params;
    std::string          model_;
    std::string          chat_template_file;
    std::string          user;
    bool                 use_jinja = false;
    int                  context_size = -1, ngl = -1, n_threads = -1;
    float                temperature = -1;
    bool                 verbose = false;

  private:
    int   context_size_default = -1, ngl_default = -1, n_threads_default = -1;
    float temperature_default = -1;
    bool  help = false;

    bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
        return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
    }

    int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) {
        if (i + 1 >= argc) {
            return 1;
        }

        option_value = std::atoi(argv[++i]);

        return 0;
    }

    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
        if (i + 1 >= argc) {
            return 1;
        }

        option_value = std::atof(argv[++i]);

        return 0;
    }

    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
        if (i + 1 >= argc) {
            return 1;
        }

        option_value = argv[++i];

        return 0;
    }

    int parse_options_with_value(int argc, const char ** argv, int & i, bool & options_parsing) {
        if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
            if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                return 1;
            }
        } else if (options_parsing &&
                   (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) {
            if (handle_option_with_value(argc, argv, i, ngl) == 1) {
                return 1;
            }
        } else if (options_parsing && (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--threads") == 0)) {
            if (handle_option_with_value(argc, argv, i, n_threads) == 1) {
                return 1;
            }
        } else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
            if (handle_option_with_value(argc, argv, i, temperature) == 1) {
                return 1;
            }
        } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) {
            if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
                return 1;
            }
            use_jinja = true;
        } else {
            return 2;
        }

        return 0;
    }

    int parse_options(const char ** argv, int & i, bool & options_parsing) {
        if (options_parsing &&
            (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
            verbose = true;
        } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
            use_jinja = true;
        } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
            help = true;
            return 0;
        } else if (options_parsing && strcmp(argv[i], "--") == 0) {
            options_parsing = false;
        } else {
            return 2;
        }

        return 0;
    }

    int parse_positional_args(const char ** argv, int & i, int & positional_args_i) {
        if (positional_args_i == 0) {
            if (!argv[i][0] || argv[i][0] == '-') {
                return 1;
            }

            ++positional_args_i;
            model_ = argv[i];
        } else if (positional_args_i == 1) {
            ++positional_args_i;
            user = argv[i];
        } else {
            user += " " + std::string(argv[i]);
        }

        return 0;
    }

    int parse(int argc, const char ** argv) {
        bool options_parsing = true;
        for (int i = 1, positional_args_i = 0; i < argc; ++i) {
            int ret = parse_options_with_value(argc, argv, i, options_parsing);
            if (ret == 0) {
                continue;
            } else if (ret == 1) {
                return ret;
            }

            ret = parse_options(argv, i, options_parsing);
            if (ret == 0) {
                continue;
            } else if (ret == 1) {
                return ret;
            }

            if (parse_positional_args(argv, i, positional_args_i)) {
                return 1;
            }
        }

        if (model_.empty()) {
            return 1;
        }

        return 0;
    }

    void print_help() const {
        printf(
            "Description:\n"
            "  Runs an LLM\n"
            "\n"
            "Usage:\n"
            "  llama-run [options] model [prompt]\n"
            "\n"
            "Options:\n"
            "  -c, --context-size <value>\n"
            "      Context size (default: %d)\n"
            "  --chat-template-file <path>\n"
            "      Path to the file containing the chat template to use with the model.\n"
            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
            "  --jinja\n"
            "      Use jinja templating for the chat template of the model\n"
            "  -n, -ngl, --ngl <value>\n"
            "      Number of GPU layers (default: %d)\n"
            "  --temp <value>\n"
            "      Temperature (default: %.1f)\n"
            "  -t, --threads <value>\n"
            "      Number of threads to use during generation (default: %d)\n"
            "  -v, --verbose, --log-verbose\n"
            "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
            "  -h, --help\n"
            "      Show help message\n"
            "\n"
            "Commands:\n"
            "  model\n"
            "      Model is a string with an optional prefix of\n"
            "      huggingface:// (hf://), modelscope:// (ms://), ollama://, https:// or file://.\n"
            "      If no protocol is specified and a file exists in the specified\n"
            "      path, file:// is assumed, otherwise if a file does not exist in\n"
            "      the specified path, ollama:// is assumed. Models that are being\n"
            "      pulled are downloaded with .partial extension while being\n"
            "      downloaded and then renamed as the file without the .partial\n"
            "      extension when complete.\n"
            "\n"
            "Examples:\n"
            "  llama-run llama3\n"
            "  llama-run ollama://granite-code\n"
            "  llama-run ollama://smollm:135m\n"
            "  llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
            "  llama-run "
            "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
            "  llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
            "  llama-run "
            "modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
            "  llama-run https://example.com/some-file1.gguf\n"
            "  llama-run some-file2.gguf\n"
            "  llama-run file://some-file3.gguf\n"
            "  llama-run --ngl 999 some-file4.gguf\n"
            "  llama-run --ngl 999 some-file5.gguf Hello World\n",
            context_size_default, ngl_default, temperature_default, n_threads_default);
    }
};

struct progress_data {
    size_t                                file_size  = 0;
    std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now();
    bool                                  printed    = false;
};

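// Return the terminal width in columns, used to size the download progress bar.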
static int get_terminal_width() {
#if defined(_WIN32)
    CONSOLE_SCREEN_BUFFER_INFO csbi;
    GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi);
    return csbi.srWindow.Right - csbi.srWindow.Left + 1;
#else
    struct winsize w;
    ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
    return w.ws_col;
#endif
}

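// RAII wrapper around a FILE *: opens a file, optionally takes an exclusive advisory lock
// (LockFileEx on Windows, flock elsewhere), and releases the lock and handle on destruction.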
class File {
  public:
    FILE * file = nullptr;

    FILE * open(const std::string & filename, const char * mode) {
        file = ggml_fopen(filename.c_str(), mode);

        return file;
    }

    int lock() {
        if (file) {
#    ifdef _WIN32
            fd    = _fileno(file);
            hFile = (HANDLE) _get_osfhandle(fd);
            if (hFile == INVALID_HANDLE_VALUE) {
                fd = -1;

                return 1;
            }

            OVERLAPPED overlapped = {};
            if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD,
                            &overlapped)) {
                fd = -1;

                return 1;
            }
#    else
            fd = fileno(file);
            if (flock(fd, LOCK_EX | LOCK_NB) != 0) {
                fd = -1;

                return 1;
            }
#    endif
        }

        return 0;
    }

    std::string to_string() {
        fseek(file, 0, SEEK_END);
        const size_t size = ftell(file);
        fseek(file, 0, SEEK_SET);
        std::string out;
        out.resize(size);
        const size_t read_size = fread(&out[0], 1, size, file);
        if (read_size != size) {
            printe("Error reading file: %s", strerror(errno));
        }

        return out;
    }

    ~File() {
        if (fd >= 0) {
#    ifdef _WIN32
            if (hFile != INVALID_HANDLE_VALUE) {
                OVERLAPPED overlapped = {};
                UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped);
            }
#    else
            flock(fd, LOCK_UN);
#    endif
        }

        if (file) {
            fclose(file);
        }
    }

  private:
    int fd = -1;
#    ifdef _WIN32
    HANDLE hFile = nullptr;
#    endif
};

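// Minimal HTTP downloader used for pulling models. Builds on libcurl when LLAMA_USE_CURL is
// defined and on cpp-httplib otherwise; downloads are written to a .partial file first, can be
// resumed, and may optionally report progress or capture the response body into a string.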
class HttpClient {
  public:
    int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
             const bool progress, std::string * response_str = nullptr) {
        if (std::filesystem::exists(output_file)) {
            return 0;
        }

        std::string output_file_partial;

        if (!output_file.empty()) {
            output_file_partial = output_file + ".partial";
        }

        if (download(url, headers, output_file_partial, progress, response_str)) {
            return 1;
        }

        if (!output_file.empty()) {
            try {
                std::filesystem::rename(output_file_partial, output_file);
            } catch (const std::filesystem::filesystem_error & e) {
                printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(),
                       e.what());
                return 1;
            }
        }

        return 0;
    }

#ifdef LLAMA_USE_CURL

    ~HttpClient() {
        if (chunk) {
            curl_slist_free_all(chunk);
        }

        if (curl) {
            curl_easy_cleanup(curl);
        }
    }

  private:
    CURL *              curl  = nullptr;
    struct curl_slist * chunk = nullptr;

    int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
                 const bool progress, std::string * response_str = nullptr) {
        curl = curl_easy_init();
        if (!curl) {
            return 1;
        }

        progress_data data;
        File          out;
        if (!output_file.empty()) {
            if (!out.open(output_file, "ab")) {
                printe("Failed to open file for writing\n");

                return 1;
            }

            if (out.lock()) {
                printe("Failed to exclusively lock file\n");

                return 1;
            }
        }

        set_write_options(response_str, out);
        data.file_size = set_resume_point(output_file);
        set_progress_options(progress, data);
        set_headers(headers);
        CURLcode res = perform(url);
        if (res != CURLE_OK) {
            printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
            return 1;
        }

        return 0;
    }

    void set_write_options(std::string * response_str, const File & out) {
        if (response_str) {
            curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data);
            curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str);
        } else {
            curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
            curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file);
        }
    }

    size_t set_resume_point(const std::string & output_file) {
        size_t file_size = 0;
        if (std::filesystem::exists(output_file)) {
            file_size = std::filesystem::file_size(output_file);
            curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast<curl_off_t>(file_size));
        }

        return file_size;
    }

    void set_progress_options(bool progress, progress_data & data) {
        if (progress) {
            curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
            curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data);
            curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress);
        }
    }

    void set_headers(const std::vector<std::string> & headers) {
        if (!headers.empty()) {
            if (chunk) {
                curl_slist_free_all(chunk);
                chunk = 0;
            }

            for (const auto & header : headers) {
                chunk = curl_slist_append(chunk, header.c_str());
            }

            curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk);
        }
    }

    CURLcode perform(const std::string & url) {
        curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
        curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
        curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
        curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
#ifdef _WIN32
        curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
        return curl_easy_perform(curl);
    }

#else // LLAMA_USE_CURL is not defined

#define curl_off_t long long // temporary hack

  private:
    // this is a direct translation of the cURL download() above
    int download(const std::string & url, const std::vector<std::string> & headers_vec,
                 const std::string & output_file, const bool progress, std::string * response_str = nullptr) {
        try {
            auto [cli, url_parts] = common_http_client(url);

            httplib::Headers headers;
            for (const auto & h : headers_vec) {
                size_t pos = h.find(':');
                if (pos != std::string::npos) {
                    headers.emplace(h.substr(0, pos), h.substr(pos + 2));
                }
            }

            File out;
            if (!output_file.empty()) {
                if (!out.open(output_file, "ab")) {
                    printe("Failed to open file for writing\n");
                    return 1;
                }
                if (out.lock()) {
                    printe("Failed to exclusively lock file\n");
                    return 1;
                }
            }

            size_t resume_offset = 0;
            if (!output_file.empty() && std::filesystem::exists(output_file)) {
                resume_offset = std::filesystem::file_size(output_file);
                if (resume_offset > 0) {
                    headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-");
                }
            }

            progress_data data;
            data.file_size = resume_offset;

            long long total_size            = 0;
            long long received_this_session = 0;

            auto response_handler = [&](const httplib::Response & response) {
                if (resume_offset > 0 && response.status != 206) {
                    printe("\nServer does not support resuming. Restarting download.\n");
                    out.file = freopen(output_file.c_str(), "wb", out.file);
                    if (!out.file) {
                        return false;
                    }
                    data.file_size = 0;
                }
                if (progress) {
                    if (response.has_header("Content-Length")) {
                        total_size = std::stoll(response.get_header_value("Content-Length"));
                    } else if (response.has_header("Content-Range")) {
                        auto range = response.get_header_value("Content-Range");
                        auto slash = range.find('/');
                        if (slash != std::string::npos) {
                            total_size = std::stoll(range.substr(slash + 1));
                        }
                    }
                }
                return true;
            };

            auto content_receiver = [&](const char * chunk, size_t length) {
                if (out.file && fwrite(chunk, 1, length, out.file) != length) {
                    return false;
                }
                if (response_str) {
                    response_str->append(chunk, length);
                }
                received_this_session += length;

                if (progress && total_size > 0) {
                    update_progress(&data, total_size, received_this_session, 0, 0);
                }
                return true;
            };

            auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver);

            if (data.printed) {
                printe("\n");
            }

            if (!res) {
                auto err = res.error();
                printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str());
                return 1;
            }

            if (res->status >= 400) {
                printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status);
                return 1;
            }

        } catch (const std::exception & e) {
            printe("HTTP request failed: %s\n", e.what());
            return 1;
        }
        return 0;
    }

#endif // LLAMA_USE_CURL

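    // Progress-reporting helpers shared by both download backends.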
    static std::string human_readable_time(double seconds) {
        int hrs  = static_cast<int>(seconds) / 3600;
        int mins = (static_cast<int>(seconds) % 3600) / 60;
        int secs = static_cast<int>(seconds) % 60;

        if (hrs > 0) {
            return string_format("%dh %02dm %02ds", hrs, mins, secs);
        } else if (mins > 0) {
            return string_format("%dm %02ds", mins, secs);
        } else {
            return string_format("%ds", secs);
        }
    }

    static std::string human_readable_size(curl_off_t size) {
        static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
        char                length   = sizeof(suffix) / sizeof(suffix[0]);
        int                 i        = 0;
        double              dbl_size = size;
        if (size > 1024) {
            for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
                dbl_size = size / 1024.0;
            }
        }

        return string_format("%.2f %s", dbl_size, suffix[i]);
    }

    static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t,
                               curl_off_t) {
        progress_data * data = static_cast<progress_data *>(ptr);
        if (total_to_download <= 0) {
            return 0;
        }

        total_to_download += data->file_size;
        const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size;
        const curl_off_t percentage      = calculate_percentage(now_downloaded_plus_file_size, total_to_download);
        std::string      progress_prefix = generate_progress_prefix(percentage);

        const double speed = calculate_speed(now_downloaded, data->start_time);
        const double tim   = (total_to_download - now_downloaded) / speed;
        std::string  progress_suffix =
            generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim);

        int         progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix);
        std::string progress_bar;
        generate_progress_bar(progress_bar_width, percentage, progress_bar);

        print_progress(progress_prefix, progress_bar, progress_suffix);
        data->printed = true;

        return 0;
    }

    static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) {
        return (now_downloaded_plus_file_size * 100) / total_to_download;
    }

    static std::string generate_progress_prefix(curl_off_t percentage) {
        return string_format("%3ld%% |", static_cast<long int>(percentage));
    }

    static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
        const auto                          now             = std::chrono::steady_clock::now();
        const std::chrono::duration<double> elapsed_seconds = now - start_time;
        return now_downloaded / elapsed_seconds.count();
    }

    static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download,
                                                double speed, double estimated_time) {
        const int width = 10;
        return string_format("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(),
                             width, human_readable_size(total_to_download).c_str(), width,
                             human_readable_size(speed).c_str(), width, human_readable_time(estimated_time).c_str());
    }

    static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) {
        int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3;
        if (progress_bar_width < 1) {
            progress_bar_width = 1;
        }

        return progress_bar_width;
    }

    static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage,
                                             std::string & progress_bar) {
        const curl_off_t pos = (percentage * progress_bar_width) / 100;
        for (int i = 0; i < progress_bar_width; ++i) {
            progress_bar.append((i < pos) ? "█" : " ");
        }

        return progress_bar;
    }

    static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
                               const std::string & progress_suffix) {
        printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str());
    }

    // Function to write data to a file
    static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
        FILE * out = static_cast<FILE *>(stream);
        return fwrite(ptr, size, nmemb, out);
    }

    // Function to capture data into a string
    static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) {
        std::string * str = static_cast<std::string *>(stream);
        str->append(static_cast<char *>(ptr), size * nmemb);
        return size * nmemb;
    }
};

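// Holds the llama.cpp objects for a chat session (model, context, sampler) and resolves the
// requested model, downloading it from Hugging Face, ModelScope, Ollama, GitHub, S3 or a plain URL if needed.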
class LlamaData {
  public:
    llama_model_ptr                 model;
    llama_sampler_ptr               sampler;
    llama_context_ptr               context;
    std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
    std::list<std::string>          msg_strs;
    std::vector<char>               fmtted;

    int init(Opt & opt) {
        model = initialize_model(opt);
        if (!model) {
            return 1;
        }

        context = initialize_context(model, opt);
        if (!context) {
            return 1;
        }

        sampler = initialize_sampler(opt);

        return 0;
    }

  private:
    int download(const std::string & url, const std::string & output_file, const bool progress,
                 const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
        HttpClient http;
        if (http.init(url, headers, output_file, progress, response_str)) {
            return 1;
        }

        return 0;
    }

    // Helper function to handle model tag extraction and URL construction
    std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
        std::string  model_tag = "latest";
        const size_t colon_pos = model.find(':');
        if (colon_pos != std::string::npos) {
            model_tag = model.substr(colon_pos + 1);
            model     = model.substr(0, colon_pos);
        }

        std::string url = base_url + model + "/manifests/" + model_tag;

        return { model, url };
    }

    // Helper function to download and parse the manifest
    int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
                                    nlohmann::json & manifest) {
        std::string manifest_str;
        int         ret = download(url, "", false, headers, &manifest_str);
        if (ret) {
            return ret;
        }

        manifest = nlohmann::json::parse(manifest_str);

        return 0;
    }

    int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) {
        // Find the second occurrence of '/' after protocol string
        size_t pos = model.find('/');
        pos        = model.find('/', pos + 1);
        std::string              hfr, hff;
        std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
        std::string              url;

        if (pos == std::string::npos) {
            auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/");
            hfr                             = model_name;

            nlohmann::json manifest;
            int            ret = download_and_parse_manifest(manifest_url, headers, manifest);
            if (ret) {
                return ret;
            }

            hff = manifest["ggufFile"]["rfilename"];
        } else {
            hfr = model.substr(0, pos);
            hff = model.substr(pos + 1);
        }

        url = model_endpoint + hfr + "/resolve/main/" + hff;

        return download(url, bn, true, headers);
    }

    int modelscope_dl(std::string & model, const std::string & bn) {
        std::string model_endpoint = "https://modelscope.cn/models/";
        return dl_from_endpoint(model_endpoint, model, bn);
    }

    int huggingface_dl(std::string & model, const std::string & bn) {
        std::string model_endpoint = get_model_endpoint();
        return dl_from_endpoint(model_endpoint, model, bn);
    }

    int ollama_dl(std::string & model, const std::string & bn) {
        const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
        if (model.find('/') == std::string::npos) {
            model = "library/" + model;
        }

        auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
        nlohmann::json manifest;
        int            ret = download_and_parse_manifest(manifest_url, {}, manifest);
        if (ret) {
            return ret;
        }

        std::string layer;
        for (const auto & l : manifest["layers"]) {
            if (l["mediaType"] == "application/vnd.ollama.image.model") {
                layer = l["digest"];
                break;
            }
        }

        std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;

        return download(blob_url, bn, true, headers);
    }

    int github_dl(const std::string & model, const std::string & bn) {
        std::string  repository = model;
        std::string  branch     = "main";
        const size_t at_pos     = model.find('@');
        if (at_pos != std::string::npos) {
            repository = model.substr(0, at_pos);
            branch     = model.substr(at_pos + 1);
        }

        const std::vector<std::string> repo_parts = string_split(repository, "/");
        if (repo_parts.size() < 3) {
            printe("Invalid GitHub repository format\n");
            return 1;
        }

        const std::string & org     = repo_parts[0];
        const std::string & project = repo_parts[1];
        std::string         url     = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch;
        for (size_t i = 2; i < repo_parts.size(); ++i) {
            url += "/" + repo_parts[i];
        }

        return download(url, bn, true);
    }

    int s3_dl(const std::string & model, const std::string & bn) {
        const size_t slash_pos = model.find('/');
        if (slash_pos == std::string::npos) {
            return 1;
        }

        const std::string bucket     = model.substr(0, slash_pos);
        const std::string key        = model.substr(slash_pos + 1);
        const char *      access_key = std::getenv("AWS_ACCESS_KEY_ID");
        const char *      secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
        if (!access_key || !secret_key) {
            printe("AWS credentials not found in environment\n");
            return 1;
        }

        // Generate AWS Signature Version 4 headers
        // (Implementation requires HMAC-SHA256 and date handling)
        // Get current timestamp
        const time_t                   now      = time(nullptr);
        const tm                       tm       = *gmtime(&now);
        const std::string              date     = strftime_fmt("%Y%m%d", tm);
        const std::string              datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
        const std::vector<std::string> headers  = {
            "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
                "/us-east-1/s3/aws4_request",
            "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
        };

        const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;

        return download(url, bn, true, headers);
    }

    std::string basename(const std::string & path) {
        const size_t pos = path.find_last_of("/\\");
        if (pos == std::string::npos) {
            return path;
        }

        return path.substr(pos + 1);
    }

    int rm_until_substring(std::string & model_, const std::string & substring) {
        const std::string::size_type pos = model_.find(substring);
        if (pos == std::string::npos) {
            return 1;
        }

        model_ = model_.substr(pos + substring.size()); // Skip past the substring
        return 0;
    }

    int resolve_model(std::string & model_) {
        int ret = 0;
        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
            rm_until_substring(model_, "://");

            return ret;
        }

        const std::string bn = basename(model_);
        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") ||
            string_starts_with(model_, "hf.co/")) {
            rm_until_substring(model_, "hf.co/");
            rm_until_substring(model_, "://");
            ret = huggingface_dl(model_, bn);
        } else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://")) {
            rm_until_substring(model_, "://");
            ret = modelscope_dl(model_, bn);
        } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
                   !string_starts_with(model_, "https://ollama.com/library/")) {
            ret = download(model_, bn, true);
        } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) {
            rm_until_substring(model_, "github:");
            rm_until_substring(model_, "://");
            ret = github_dl(model_, bn);
        } else if (string_starts_with(model_, "s3://")) {
            rm_until_substring(model_, "://");
            ret = s3_dl(model_, bn);
        } else { // ollama:// or nothing
            rm_until_substring(model_, "ollama.com/library/");
            rm_until_substring(model_, "://");
            ret = ollama_dl(model_, bn);
        }

        model_ = bn;

        return ret;
    }

    // Initializes the model and returns a unique pointer to it
    llama_model_ptr initialize_model(Opt & opt) {
        ggml_backend_load_all();
        resolve_model(opt.model_);
        printe("\r" LOG_CLR_TO_EOL "Loading model");
        llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params));
        if (!model) {
            printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
        }

        printe("\r" LOG_CLR_TO_EOL);
        return model;
    }

    // Initializes the context with the specified parameters
    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
        llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params));
        if (!context) {
            printe("%s: error: failed to create the llama_context\n", __func__);
        }

        return context;
    }

    // Initializes and configures the sampler
    llama_sampler_ptr initialize_sampler(const Opt & opt) {
        llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
        llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
        llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

        return sampler;
    }
};

// Add a message to `messages` and store its content in `msg_strs`
static void add_message(const char * role, const std::string & text, LlamaData & llama_data) {
    llama_data.msg_strs.push_back(std::move(text));
    llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
}

// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append,
                               bool use_jinja) {
    common_chat_templates_inputs inputs;
    for (const auto & msg : llama_data.messages) {
        common_chat_msg cmsg;
        cmsg.role    = msg.role;
        cmsg.content = msg.content;
        inputs.messages.push_back(cmsg);
    }
    inputs.add_generation_prompt = append;
    inputs.use_jinja             = use_jinja;

    auto chat_params = common_chat_templates_apply(tmpls, inputs);
    // TODO: use other params for tool calls.
    auto result = chat_params.prompt;
    llama_data.fmtted.resize(result.size() + 1);
    memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
    return result.size();
}

// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
    const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
    int n_tokens = prompt.size() + 2 * is_first;
    prompt_tokens.resize(n_tokens);
    n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
                              prompt_tokens.data(), prompt_tokens.size(),
                              is_first, /*parse_special =*/ true);
    if (n_tokens == std::numeric_limits<int32_t>::min()) {
        printe("tokenization failed: input too large\n");
        return -1;
    }
    if (n_tokens < 0) {
        prompt_tokens.resize(-n_tokens);
        int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
                                   prompt_tokens.data(), prompt_tokens.size(),
                                   is_first, /*parse_special =*/ true);
        if (check != -n_tokens) {
            printe("failed to tokenize the prompt (size mismatch)\n");
            return -1;
        }
        n_tokens = check;
    } else {
        prompt_tokens.resize(n_tokens);
    }
    return n_tokens;
}

// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
    const int n_ctx      = llama_n_ctx(ctx.get());
    const int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx.get()), 0);
    if (n_ctx_used + batch.n_tokens > n_ctx) {
        printf(LOG_COL_DEFAULT "\n");
        printe("context size exceeded\n");
        return 1;
    }

    return 0;
}

// convert the token to a string
static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) {
    char buf[256];
    int  n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true);
    if (n < 0) {
        printe("failed to convert token to piece\n");
        return 1;
    }

    piece = std::string(buf, n);
    return 0;
}

static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
    printf("%s", piece.c_str());
    fflush(stdout);
    response += piece;
}

// helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
    const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());

    std::vector<llama_token> tokens;
    if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
        return 1;
    }

    // prepare a batch for the prompt
    llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
    llama_token new_token_id;
    while (true) {
        check_context_size(llama_data.context, batch);
        if (llama_decode(llama_data.context.get(), batch)) {
            printe("failed to decode\n");
            return 1;
        }

        // sample the next token and check whether it is an end-of-generation token
        new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
        if (llama_vocab_is_eog(vocab, new_token_id)) {
            break;
        }

        std::string piece;
        if (convert_token_to_string(vocab, new_token_id, piece)) {
            return 1;
        }

        print_word_and_concatenate_to_response(piece, response);

        // prepare the next batch with the sampled token
        batch = llama_batch_get_one(&new_token_id, 1);
    }

    printf(LOG_COL_DEFAULT);
    return 0;
}

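// Read one line of user input; returns 0 on success, 1 to exit (EOF or "/bye") and 2 for an empty line.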
static int read_user_input(std::string & user_input) {
    static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX");
    static const char * prompt_prefix     = prompt_prefix_env ? prompt_prefix_env : "> ";
#ifdef WIN32
    printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);

    std::getline(std::cin, user_input);
    if (std::cin.eof()) {
        printf("\n");
        return 1;
    }
#else
    std::unique_ptr<char, decltype(&std::free)> line(const_cast<char *>(linenoise(prompt_prefix)), free);
    if (!line) {
        return 1;
    }

    user_input = line.get();
#endif

    if (user_input == "/bye") {
        return 1;
    }

    if (user_input.empty()) {
        return 2;
    }

#ifndef WIN32
    linenoiseHistoryAdd(line.get());
#endif

    return 0; // Should have data in happy path
}

// Function to generate a response based on the prompt
static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response,
                             const bool stdout_a_terminal) {
    // Set response color
    if (stdout_a_terminal) {
        printf(LOG_COL_YELLOW);
    }

    if (generate(llama_data, prompt, response)) {
        printe("failed to generate response\n");
        return 1;
    }

    // End response with color reset and newline
    printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : "");
    return 0;
}

// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data,
                                                   const bool append, int & output_length, bool use_jinja) {
    const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
    if (new_len < 0) {
        printe("failed to apply the chat template\n");
        return -1;
    }

    output_length = new_len;
    return 0;
}

// Helper function to handle user input
static int handle_user_input(std::string & user_input, const std::string & user) {
    if (!user.empty()) {
        user_input = user;
        return 0; // No need for interactive input
    }

    return read_user_input(user_input); // Returns 1 if the input should end the loop
}

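// Detect whether stdin/stdout are attached to a terminal (console mode on Windows, isatty elsewhere).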
static bool is_stdin_a_terminal() {
#if defined(_WIN32)
    HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
    DWORD  mode;
    return GetConsoleMode(hStdin, &mode);
#else
    return isatty(STDIN_FILENO);
#endif
}

static bool is_stdout_a_terminal() {
#if defined(_WIN32)
    HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD  mode;
    return GetConsoleMode(hStdout, &mode);
#else
    return isatty(STDOUT_FILENO);
#endif
}

// Function to handle user input
static int get_user_input(std::string & user_input, const std::string & user) {
    while (true) {
        const int ret = handle_user_input(user_input, user);
        if (ret == 1) {
            return 1;
        }

        if (ret == 2) {
            continue;
        }

        break;
    }

    return 0;
}

// Reads a chat template file to be used
static std::string read_chat_template_file(const std::string & chat_template_file) {
    File file;
    if (!file.open(chat_template_file, "r")) {
        printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
        return "";
    }

    return file.to_string();
}

static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
                                const common_chat_templates_ptr & chat_templates, int & prev_len,
                                const bool stdout_a_terminal) {
    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
    int new_len;
    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
        return 1;
    }

    std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
    std::string response;
    if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
        return 1;
    }

    if (!opt.user.empty()) {
        return 2;
    }

    add_message("assistant", response, llama_data);
    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
        return 1;
    }

    return 0;
}

// Main chat loop function
static int chat_loop(LlamaData & llama_data, const Opt & opt) {
    int prev_len = 0;
    llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
    std::string chat_template;
    if (!opt.chat_template_file.empty()) {
        chat_template = read_chat_template_file(opt.chat_template_file);
    }

    common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
    static const bool         stdout_a_terminal = is_stdout_a_terminal();
    while (true) {
        // Get user input
        std::string user_input;
        if (get_user_input(user_input, opt.user) == 1) {
            return 0;
        }

        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
        if (ret == 1) {
            return 1;
        } else if (ret == 2) {
            break;
        }
    }

    return 0;
}

static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
    const Opt * opt = static_cast<Opt *>(p);
    if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) {
        printe("%s", text);
    }
}

static std::string read_pipe_data() {
    std::ostringstream result;
    result << std::cin.rdbuf(); // Read all data from std::cin
    return result.str();
}

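// Install a Ctrl-C / SIGINT handler so the program can reset the terminal color and exit cleanly.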
static void ctrl_c_handling() {
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = sigint_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
#elif defined(_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
}

int main(int argc, const char ** argv) {
    ctrl_c_handling();
    Opt       opt;
    const int ret = opt.init(argc, argv);
    if (ret == 2) {
        return 0;
    } else if (ret) {
        return 1;
    }

    if (!is_stdin_a_terminal()) {
        if (!opt.user.empty()) {
            opt.user += "\n\n";
        }

        opt.user += read_pipe_data();
    }

    llama_log_set(log_callback, &opt);
    LlamaData llama_data;
    if (llama_data.init(opt)) {
        return 1;
    }

    if (chat_loop(llama_data, opt)) {
        return 1;
    }

    return 0;
}