#include "llama.h"
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}

int main(int argc, char ** argv) {
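    // defaults: offload up to 99 model layers to the GPU (effectively the whole model) and use a 2048-token context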
    std::string model_path;
    int ngl = 99;
    int n_ctx = 2048;

    // parse command line arguments
    for (int i = 1; i < argc; i++) {
        try {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-c") == 0) {
                if (i + 1 < argc) {
                    n_ctx = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    ngl = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                print_usage(argc, argv);
                return 1;
            }
        } catch (std::exception & e) {
            fprintf(stderr, "error: %s\n", e.what());
            print_usage(argc, argv);
            return 1;
        }
    }
    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }

    // only print errors
    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
        if (level >= GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "%s", text);
        }
    }, nullptr);

    // load dynamic backends
    ggml_backend_load_all();

    // initialize the model
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
    if (!model) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // initialize the context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
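    // n_batch is the maximum number of tokens that can be submitted in a single llama_decode call;
    // matching it to n_ctx lets an entire formatted prompt be evaluated at once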
    ctx_params.n_batch = n_ctx;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    // initialize the sampler
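    // samplers run in the order they are added to the chain: min-p filtering, then temperature scaling,
    // then a seeded random draw from the remaining candidates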
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // helper function to evaluate a prompt and generate a response
    auto generate = [&](const std::string & prompt) {
        std::string response;

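        // the context memory (KV cache) is empty only before the first turn
        // (llama_memory_seq_pos_max returns -1 for an empty sequence); special
        // tokens such as BOS are added only when tokenizing that first chunk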
        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;

        // tokenize the prompt
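        // a first call with a NULL token buffer returns the negative of the required token count,
        // which is used here to size the vector before the real tokenization pass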
        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
            GGML_ABORT("failed to tokenize the prompt\n");
        }

        // prepare a batch for the prompt
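        // llama_batch_get_one wraps the token array as a batch for a single sequence;
        // token positions are tracked automatically by llama_decode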
        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
        llama_token new_token_id;
        while (true) {
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
            int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) + 1;
            if (n_ctx_used + batch.n_tokens > n_ctx) {
                printf("\033[0m\n");
                fprintf(stderr, "context size exceeded\n");
                exit(0);
            }

            int ret = llama_decode(ctx, batch);
            if (ret != 0) {
                GGML_ABORT("failed to decode, ret = %d\n", ret);
            }

            // sample the next token
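            // idx == -1 samples from the logits of the last token in the batch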
            new_token_id = llama_sampler_sample(smpl, ctx, -1);

            // is it an end of generation?
            if (llama_vocab_is_eog(vocab, new_token_id)) {
                break;
            }

            // convert the token to a string, print it and add it to the response
            char buf[256];
            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
            if (n < 0) {
                GGML_ABORT("failed to convert token to piece\n");
            }
            std::string piece(buf, n);
            printf("%s", piece.c_str());
            fflush(stdout);
            response += piece;

            // prepare the next batch with the sampled token
            batch = llama_batch_get_one(&new_token_id, 1);
        }

        return response;
    };

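    // messages holds the whole chat history; formatted receives the chat-template output,
    // and prev_len marks how much of that output has already been sent to the model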
    std::vector<llama_chat_message> messages;
    std::vector<char> formatted(llama_n_ctx(ctx));
    int prev_len = 0;
    while (true) {
        // get user input
        printf("\033[32m> \033[0m");
        std::string user;
        std::getline(std::cin, user);

        if (user.empty()) {
            break;
        }

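        // use the chat template stored in the model's GGUF metadata (nullptr selects the default template)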
        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);

        // add the user input to the message list and format it
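        // llama_chat_message stores plain C string pointers, so the text is duplicated here
        // and freed in the cleanup loop at the end of main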
        messages.push_back({"user", strdup(user.c_str())});
        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        if (new_len > (int)formatted.size()) {
            formatted.resize(new_len);
            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        }
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        // remove previous messages to obtain the prompt to generate the response
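        // (the earlier turns are already stored in the context's memory, so only the newly formatted part is decoded)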
        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);

        // generate a response
        printf("\033[33m");
        std::string response = generate(prompt);
        printf("\n\033[0m");

        // add the response to the messages
        messages.push_back({"assistant", strdup(response.c_str())});
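        // with a null buffer, llama_chat_apply_template only reports the length of the formatted history,
        // which becomes the offset (prev_len) where the next turn's prompt starts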
        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    // free resources
    for (auto & msg : messages) {
        free(const_cast<char *>(msg.content));
    }
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_model_free(model);

    return 0;
}