#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
    LOG("\n    %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]);
    LOG("\n");
}

int main(int argc, char ** argv) {
    common_params params;

    params.prompt = "Hello my name is";
    params.n_predict = 32;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
        return 1;
    }

    common_init();

    // number of parallel batches
    int n_parallel = params.n_parallel;

    // total length of the sequences including the prompt
    int n_predict = params.n_predict;

    // init LLM

    llama_backend_init();
    llama_numa_init(params.numa);

    // initialize the model

    llama_model_params model_params = common_model_params_to_llama(params);

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);

    if (model == NULL) {
        LOG_ERR("%s: error: unable to load model\n", __func__);
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // tokenize the prompt

    std::vector<llama_token> tokens_list;
    tokens_list = common_tokenize(vocab, params.prompt, true);

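    // required KV cache size: the prompt is stored once and shared by all sequences,
    // while each of the n_parallel sequences generates up to (n_predict - prompt) tokens of its own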
    const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;

    // initialize the context

    llama_context_params ctx_params = common_context_params_to_llama(params);

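    // the context must hold exactly the required number of KV cells; the batch must be able to
    // hold the whole prompt in one call (n_predict includes the prompt) and, later, one token per parallel sequence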
    ctx_params.n_ctx   = n_kv_req;
    ctx_params.n_batch = std::max(n_predict, n_parallel);

    llama_context * ctx = llama_init_from_model(model, ctx_params);

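    // build the sampler chain: top-k -> top-p -> temperature, then draw the final token
    // from the remaining distribution using the configured seed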
    auto sparams = llama_sampler_chain_default_params();
    sparams.no_perf = false;

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

    if (ctx == NULL) {
        LOG_ERR("%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    const int n_ctx = llama_n_ctx(ctx);

    LOG_INF("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);

    // make sure the KV cache is big enough to hold all the prompt and generated tokens
    if (n_kv_req > n_ctx) {
        LOG_ERR("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
        LOG_ERR("%s: either reduce n_parallel or increase n_ctx\n", __func__);
        return 1;
    }

    // print the prompt token-by-token

    LOG("\n");

    for (auto id : tokens_list) {
        LOG("%s", common_token_to_piece(ctx, id).c_str());
    }

    // create a llama_batch
    // we use this object to submit token data for decoding
    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);

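    // list of all sequence ids - the prompt tokens are added to every sequence at once,
    // so the prompt is decoded a single time and its KV cells are shared by all streams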
    std::vector<llama_seq_id> seq_ids(n_parallel, 0);
    for (int32_t i = 0; i < n_parallel; ++i) {
        seq_ids[i] = i;
    }

    // evaluate the initial prompt
    for (size_t i = 0; i < tokens_list.size(); ++i) {
        common_batch_add(batch, tokens_list[i], i, seq_ids, false);
    }
    GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());

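    // for encoder-decoder models, run the encoder over the prompt first and restart the batch
    // with the decoder start token (falling back to BOS if the model does not define one)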
    if (llama_model_has_encoder(model)) {
        if (llama_encode(ctx, batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return 1;
        }

        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
        if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
            decoder_start_token_id = llama_vocab_bos(vocab);
        }

        common_batch_clear(batch);
        common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
    }

    // llama_decode will output logits only for the last token of the prompt
    batch.logits[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
        return 1;
    }

    //// assign the system KV cache to all parallel sequences
    //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
    //for (int32_t i = 1; i < n_parallel; ++i) {
    //    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
    //}
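    // (not needed here: the prompt tokens were added with all sequence ids above,
    //  so the streams already share the prompt's KV cells)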

    if (n_parallel > 1) {
        LOG("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
    }

    // main loop

    // we will store the parallel decoded sequences in this vector
    std::vector<std::string> streams(n_parallel);

    // remember the batch index of the last token for each parallel sequence
    // we need this to determine which logits to sample from
    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);

    int n_cur    = batch.n_tokens;
    int n_decode = 0;

    const auto t_main_start = ggml_time_us();

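    // each iteration samples one new token for every unfinished stream and submits
    // all of them to llama_decode in a single batch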
    while (n_cur <= n_predict) {
        // prepare the next batch
        common_batch_clear(batch);

        // sample the next token for each parallel sequence / stream
        for (int32_t i = 0; i < n_parallel; ++i) {
            if (i_batch[i] < 0) {
                // the stream has already finished
                continue;
            }

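            // sample the next token using the logits at this stream's position in the last decoded batch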
            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);

            // is it an end of generation? -> mark the stream as finished
            if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
                i_batch[i] = -1;
                LOG("\n");
                if (n_parallel > 1) {
                    LOG_INF("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
                }

                continue;
            }

            // if there is only one stream, we print immediately to stdout
            if (n_parallel == 1) {
                LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());
            }

            streams[i] += common_token_to_piece(ctx, new_token_id);

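            // remember where this stream's new token lands in the upcoming batch,
            // so we know which logits to sample from after the next decode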
            i_batch[i] = batch.n_tokens;

            // push this new token for next evaluation
            common_batch_add(batch, new_token_id, n_cur, { i }, true);

            n_decode += 1;
        }

        // all streams are finished
        if (batch.n_tokens == 0) {
            break;
        }

        n_cur += 1;

        // evaluate the current batch with the transformer model
        const int ret = llama_decode(ctx, batch);
        if (ret != 0) {
            LOG_ERR("%s : failed to eval, return code %d\n", __func__, ret);
            return 1;
        }
    }

    if (n_parallel > 1) {
        LOG("\n");

        for (int32_t i = 0; i < n_parallel; ++i) {
            LOG("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
        }
    }

    const auto t_main_end = ggml_time_us();

    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    LOG("\n");
    llama_perf_sampler_print(smpl);
    llama_perf_context_print(ctx);

    fprintf(stderr, "\n");

    llama_batch_free(batch);

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_model_free(model);

    llama_backend_free();

    return 0;
}