#include "llama.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n    %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [prompt]\n", argv[0]);
    printf("\n");
}

int main(int argc, char ** argv) {
    // path to the model gguf file
    std::string model_path;
    // prompt to generate text from
    std::string prompt = "Hello my name is";
    // number of layers to offload to the GPU
    int ngl = 99;
    // number of tokens to predict
    int n_predict = 32;

    // parse command line arguments

    {
        int i = 1;
        for (; i < argc; i++) {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-n") == 0) {
                if (i + 1 < argc) {
                    try {
                        n_predict = std::stoi(argv[++i]);
                    } catch (...) {
                        print_usage(argc, argv);
                        return 1;
                    }
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    try {
                        ngl = std::stoi(argv[++i]);
                    } catch (...) {
                        print_usage(argc, argv);
                        return 1;
                    }
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                // prompt starts here
                break;
            }
        }
        if (model_path.empty()) {
            print_usage(argc, argv);
            return 1;
        }
        if (i < argc) {
            prompt = argv[i++];
            for (; i < argc; i++) {
                prompt += " ";
                prompt += argv[i];
            }
        }
    }

    // load dynamic backends

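    // ggml_backend_load_all() discovers and registers the available compute backends
    // (CPU and, if built, GPU backends such as CUDA or Metal); this is what allows
    // n_gpu_layers below to actually offload work to a GPU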
    ggml_backend_load_all();

    // initialize the model

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);

    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // tokenize the prompt

    // find the number of tokens in the prompt
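    // note: with a NULL token buffer and n_tokens_max == 0, llama_tokenize returns the
    //       negative of the number of tokens required, so negating the result gives n_prompt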
    const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);

    // allocate space for the tokens and tokenize the prompt
    std::vector<llama_token> prompt_tokens(n_prompt);
    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
        return 1;
    }

    // initialize the context

    llama_context_params ctx_params = llama_context_default_params();
    // n_ctx is the context size
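    // the last sampled token is printed but never fed back through llama_decode,
    // so n_prompt + n_predict - 1 positions are enough for the KV cache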
    ctx_params.n_ctx = n_prompt + n_predict - 1;
    // n_batch is the maximum number of tokens that can be processed in a single call to llama_decode
    ctx_params.n_batch = n_prompt;
    // enable performance counters
    ctx_params.no_perf = false;

    llama_context * ctx = llama_init_from_model(model, ctx_params);

    if (ctx == NULL) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    // initialize the sampler

    auto sparams = llama_sampler_chain_default_params();
    sparams.no_perf = false;
    llama_sampler * smpl = llama_sampler_chain_init(sparams);

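    // the greedy sampler always picks the token with the highest probability;
    // other samplers (top-k, temperature, ...) could be chained here instead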
    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // print the prompt token-by-token

    for (auto id : prompt_tokens) {
        char buf[128];
        int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
        if (n < 0) {
            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
            return 1;
        }
        std::string s(buf, n);
        printf("%s", s.c_str());
    }

    // prepare a batch for the prompt

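    // llama_batch_get_one wraps the existing token array in a single-sequence batch view
    // without copying; token positions are tracked automatically by llama_decode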
    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());

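    // encoder-decoder models (e.g. T5) first run the prompt through the encoder and then
    // start decoding from the model's decoder start token, falling back to BOS if unset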
    if (llama_model_has_encoder(model)) {
        if (llama_encode(ctx, batch)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }

        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
        if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
            decoder_start_token_id = llama_vocab_bos(vocab);
        }

        batch = llama_batch_get_one(&decoder_start_token_id, 1);
    }

    // main loop

    const auto t_main_start = ggml_time_us();
    int n_decode = 0;
    llama_token new_token_id;

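    // keep decoding until n_prompt + n_predict positions have been processed
    // or an end-of-generation token is sampled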
    for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
        // evaluate the current batch with the transformer model
        if (llama_decode(ctx, batch)) {
            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }

        n_pos += batch.n_tokens;

        // sample the next token
        {
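            // idx == -1 samples from the logits of the last token in the batch just decoded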
            new_token_id = llama_sampler_sample(smpl, ctx, -1);

            // is it an end of generation?
            if (llama_vocab_is_eog(vocab, new_token_id)) {
                break;
            }

            char buf[128];
            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
            if (n < 0) {
                fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
                return 1;
            }
            std::string s(buf, n);
            printf("%s", s.c_str());
            fflush(stdout);

            // prepare the next batch with the sampled token
            batch = llama_batch_get_one(&new_token_id, 1);

            n_decode += 1;
        }
    }

    printf("\n");

    const auto t_main_end = ggml_time_us();

    fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    fprintf(stderr, "\n");
    llama_perf_sampler_print(smpl);
    llama_perf_context_print(ctx);
    fprintf(stderr, "\n");

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_model_free(model);

    return 0;
}