#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>
#include <algorithm>

struct ngram_data {
    bool active = false;

    llama_seq_id seq_id = -1;

    std::vector<int> i_batch;

    std::vector<llama_token> tokens;
};

// n-gram container
struct ngram_container {
    ngram_container(int n_vocab, int N, int G) {
        cnt.resize(n_vocab);
        head.resize(n_vocab);
        tokens.resize(n_vocab * G * (N - 1));
    }

    int n_total = 0;

    std::vector<int> cnt;
    std::vector<int> head;

    // [n_vocab][G][N - 1]
    // for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1
    std::vector<llama_token> tokens;
};

int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    common_init();

    const int W = 15; // lookahead window
    const int N = 5;  // n-gram size
    const int G = 15; // max verification n-grams

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the target model
    common_init_result llama_init = common_init_from_params(params);

    llama_model * model = llama_init.model.get();
    llama_context * ctx = llama_init.context.get();

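    // handle to the context's KV cache memory - used below for the llama_memory_seq_* sequence operations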
    auto * mem = llama_get_memory(ctx);

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // Tokenize the prompt
    std::vector<llama_token> inp;
    std::vector<llama_token> all;

    inp = common_tokenize(ctx, params.prompt, true, true);
    all = inp;

    const int max_context_size     = llama_n_ctx(ctx);
    const int max_tokens_list_size = max_context_size - 4;

    if ((int) inp.size() > max_tokens_list_size) {
        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
        return 1;
    }

    LOG("\n\n");

    for (auto id : inp) {
        LOG("%s", common_token_to_piece(ctx, id).c_str());
    }

    fflush(stderr);

    const int n_input = inp.size();

    const auto t_enc_start = ggml_time_us();

    // eval the prompt
    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
    llama_decode(ctx, llama_batch_get_one(&inp.back(),           1));

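    // copy the prompt's KV cache state to all other sequences so every lookahead/verification branch shares the same prefix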
    for (int s = 1; s < W + G + 1; ++s) {
        llama_memory_seq_cp(mem, 0, s, -1, -1);
    }

    const auto t_enc_end = ggml_time_us();

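    // generation statistics: total sampled tokens and tokens accepted from the verification n-grams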
    int n_predict = 0;
    int n_accept  = 0;

    int n_past = inp.size();

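    // the most recently sampled token - it becomes the first token of the next lookahead batch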
    llama_token id = 0;

    // used to determine end of generation
    bool has_eos = false;

    // for each decoded batch, we have at most W + G + 1 distinct sequences:
    // seq_id == 0           : the current input token
    // seq_id [1, W]         : tokens from the past N - 1 Jacobi iterations
    // seq_id [W + 1, W + G] : verification n-grams
    llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

    // target model sampling context
    struct common_sampler * smpl = common_sampler_init(model, params.sampling);

    // verification n-grams
    std::vector<ngram_data> ngrams_cur(G);

    // tokens for the past N - 1 Jacobi iterations
    std::vector<llama_token> tokens_j_prev(W);
    std::vector<std::vector<llama_token>> tokens_j(N - 1);
    for (int j = 0; j < N - 1; j++) {
        tokens_j[j].resize(W);

        for (int i = 0; i < W; i++) {
            // there are different ways to init these tokens
            if (0) {
                // initialize randomly from the prompt tokens
                tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
            } else {
                // initialize with a sequence of increasing numbers
                tokens_j[j][i] = 100 + i;
            }
        }
    }

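    // seq_ids assigned to the lookahead tokens of the first level (refilled per token while building the batch)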
    std::vector<llama_seq_id> seq_id_look;

    // the input token belongs to all sequences
    std::vector<llama_seq_id> seq_id_all(W + G + 1);
    for (int i = 0; i < W + G + 1; i++) {
        seq_id_all[i] = i;
    }

    // here we keep adding new n-grams as we go
    ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G);

    const auto t_dec_start = ggml_time_us();

    // sample first token
    {
        id = common_sampler_sample(smpl, ctx, 0);

        common_sampler_accept(smpl, id, true);

        {
            const std::string token_str = common_token_to_piece(ctx, id);

            LOG("%s", token_str.c_str());
            fflush(stdout);
        }
    }

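    // main decoding loop: each iteration submits one lookahead batch and the verification step below can accept up to N tokens from it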
    while (true) {
        // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
        //
        // Example for W = 5, N = 4, G = 2:
        // (I = input, L = lookahead, V = verification)
        //
        // Batch:    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
        // T:           -2 -2 -2 -2 -1 -1 -1 -1 -1  0  0  0  0  0  0
        // Info:     I  L  L  L  L  L  L  L  L  L  L  L  L  L  L  V  V  V  V  V  V
        // Pos:      0  1  2  3  4  1  2  3  4  5  2  3  4  5  6  1  2  3  1  2  3   (+ n_past)
        // Logits:   1  0  0  0  0  0  0  0  0  0  1  1  1  1  1  1  1  1  1  1  1
        // ---------------------------------------------------------------------
        // Seq:      0
        //           1              1              1
        //           2  2              2              2
        //           3  3  3              3              3
        //           4  4  4  4              4              4
        //           5  5  5  5  5              5              5
        //           6                                            6  6  6
        //           7                                                     7  7  7
        // ---------------------------------------------------------------------
        //                                           |  |  |  |  |  |  |  |  |  |  |
        //                                           V  V  V  V  V  |  |  |  |  |  |
        //                                              j_tokens    |  |  |  |  |  |
        //                                                          V  V  V  V  V  V
        //                                                                  id
        {
            common_batch_clear(batch);

            // current token - first token of the first level
            common_batch_add(batch, id, n_past, seq_id_all, true);

            // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
            {
                const int g_cur = ngrams_observed.cnt[id];

                ngrams_cur.resize(g_cur);
                for (int g = 0; g < g_cur; g++) {
                    ngrams_cur[g].active = true;
                    ngrams_cur[g].tokens.resize(N);
                    ngrams_cur[g].i_batch.resize(N);
                    ngrams_cur[g].seq_id = W + 1 + g;
                    ngrams_cur[g].i_batch[0] = 0;
                    ngrams_cur[g].tokens [0] = id;
                }

                for (int j = 0; j < N - 1; j++) {
                    for (int g = 0; g < g_cur; g++) {
                        const int idx = id*(N - 1)*G + g*(N - 1);

                        const llama_token t = ngrams_observed.tokens[idx + j];

                        ngrams_cur[g].tokens [j + 1] = t;
                        ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;

                        common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
                    }
                }
            }

            // fill the remaining W - 1 tokens for the first level
            for (int i = 1; i < W; i++) {
                seq_id_look.resize(W - i);
                for (int j = 0; j < W - i; j++) {
                    seq_id_look[j] = i + j + 1;
                }

                common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
            }

            // fill the rest of the levels
            for (int j = 1; j < N - 1; j++) {
                for (int i = 0; i < W; i++) {
                    common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
                }
            }
        }

        if (llama_decode(ctx, batch) != 0) {
            LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
            return 1;
        }

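        // seq_id of the verification n-gram that matched (0 means no n-gram was accepted)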
        int seq_id_best = 0;

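        // verification: sample up to N tokens, continuing only while an active n-gram agrees with each newly sampled token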
        for (int v = 0; v < N; ++v) {
            int i_batch = 0;

            // if no active ngrams are left, it means the sampled token does not pass the verification
            if (v > 0) {
                for (int g = 0; g < (int) ngrams_cur.size(); g++) {
                    if (ngrams_cur[g].active) {
                        i_batch     = ngrams_cur[g].i_batch[v];
                        seq_id_best = ngrams_cur[g].seq_id;

                        ++n_accept;
                        break;
                    }
                }

                // no more matches -> create a new batch
                if (i_batch == 0) {
                    break;
                }
            }

            // sample the next token
            id = common_sampler_sample(smpl, ctx, i_batch);

            common_sampler_accept(smpl, id, true);

            // print
            {
                const std::string token_str = common_token_to_piece(ctx, id);

                if (v == 0) {
                    LOG("%s", token_str.c_str());
                } else {
                    // print light cyan
                    LOG("\033[0;96m%s\033[0m", token_str.c_str());
                }
                fflush(stdout);

                if (llama_vocab_is_eog(vocab, id)) {
                    has_eos = true;
                }

                all.push_back(id);
            }

            ++n_predict;
            ++n_past;

            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
                break;
            }

            // verify across active n-grams
            for (int g = 0; g < (int) ngrams_cur.size(); g++) {
                if (ngrams_cur[g].active) {
                    if (v == N - 1) {
                        ngrams_cur[g].active = false;
                    } else {
                        if (id != ngrams_cur[g].tokens[v + 1]) {
                            ngrams_cur[g].active = false;
                        }
                    }
                }
            }

            // print known n-grams starting with token id (debug)
            if (0 && v == 0) {
                if (ngrams_observed.cnt[id] > 0) {
                    LOG("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], common_token_to_piece(ctx, id).c_str());
                }

                for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
                    LOG(" - ngram %2d: ", i);

                    const int idx = id*(N - 1)*G + i*(N - 1);

                    for (int j = 0; j < N - 1; j++) {
                        const std::string token_str = common_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);

                        LOG("%s", token_str.c_str());
                    }

                    LOG("\n");
                }
            }

            // update lookahead tokens
            {
                for (int i = 0; i < W; i++) {
                    tokens_j_prev[i] = tokens_j[0][i];
                }

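                // shift the levels down by one: the oldest level (saved above in tokens_j_prev) is dropped and the last level is refilled below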
                for (int j = 0; j < N - 2; j++) {
                    tokens_j[j] = tokens_j[j + 1];
                }

                if (v == 0) {
                    // sample from the last level
                    for (int i = 0; i < W; i++) {
                        tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                    }
                } else {
                    for (int i = 0; i < W; i++) {
                        // there are different ways to init these tokens
                        if (0) {
                            // random init
                            tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
                        } else {
                            // init from the previous level
                            tokens_j[N - 2][i] = tokens_j[0][i];
                        }
                    }
                }
            }

            // update observed ngrams
            if (v == 0) {
                // the first token of the n-gram is determined by the index in the container so it is not stored
                std::vector<llama_token> ngram(N - 1);

                // n-gram generation
                // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
                for (int f = 0; f < W; ++f) {
                    const int ft = tokens_j_prev[f]; // first token of the n-gram

                    for (int j = 0; j < N - 1; ++j) {
                        ngram[j] = tokens_j[j][f];
                    }

                    // filter out repeated n-grams
                    {
                        bool is_unique = true;

                        for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
                            const int idx = ft*(N - 1)*G + k*(N - 1);

                            bool is_match = true;
                            for (int j = 0; j < N - 1; ++j) {
                                if (ngrams_observed.tokens[idx + j] != ngram[j]) {
                                    is_match = false;
                                    break;
                                }
                            }

                            if (is_match) {
                                is_unique = false;
                                break;
                            }
                        }

                        if (!is_unique) {
                            continue;
                        }
                    }

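                    // add the new n-gram to the ring buffer of n-grams that start with token ft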
                    const int head = ngrams_observed.head[ft];
                    const int idx  = ft*(N - 1)*G + head*(N - 1);

                    for (int i = 0; i < N - 1; i++) {
                        ngrams_observed.tokens[idx + i] = ngram[i];
                    }

                    ngrams_observed.cnt[ft]  = std::min(G, ngrams_observed.cnt[ft] + 1);
                    ngrams_observed.head[ft] = (head + 1) % G;

                    ngrams_observed.n_total++;
                }
            }
        }

        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
            break;
        }

        // KV cache management
        // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
        llama_memory_seq_rm(mem, -1, n_past, -1);

        if (seq_id_best != 0) {
            // if a verification token matched, we keep the best sequence and remove the rest
            // this leads to some KV cache fragmentation
            llama_memory_seq_keep(mem, seq_id_best);
            llama_memory_seq_cp  (mem, seq_id_best, 0, -1, -1);
            llama_memory_seq_rm  (mem, seq_id_best,    -1, -1);

            for (int s = 1; s < W + G + 1; ++s) {
                llama_memory_seq_cp(mem, 0, s, -1, -1);
            }
        }
    }

    auto t_dec_end = ggml_time_us();

    LOG("\n\n");

    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));

    LOG_INF("\n");
    LOG_INF("W = %2d\n", W);
    LOG_INF("N = %2d\n", N);
    LOG_INF("G = %2d\n", G);
    LOG_INF("\n");
    LOG_INF("n_predict = %d\n", n_predict);
    LOG_INF("n_accept  = %d\n", n_accept);

    LOG_INF("\n");
    common_perf_print(ctx, smpl);

    common_sampler_free(smpl);

    llama_batch_free(batch);

    llama_backend_free();

    LOG("\n\n");

    return 0;
}