#include "arg.h"
#include "common.h"
#include "llama.h"

#include <vector>
#include <cstdio>

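// This example demonstrates saving and restoring llama_context state:
//
//   1. evaluate a prompt, serialize the full context state to a file and
//      generate a baseline continuation
//   2. load the state into a fresh context and verify that the generation
//      is identical
//   3. load the state into another fresh context, copy the KV cache data of
//      sequence 0 into sequence 1 and verify that generating on sequence 1
//      also matches
//
// possible invocation (the binary name and model path are examples):
//
//   ./llama-save-load-state -m model.gguf -p "The quick brown fox"
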
int main(int argc, char ** argv) {
    common_params params;

    params.prompt = "The quick brown fox";
    params.sampling.seed = 1234;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    if (params.n_parallel == 1) {
        // the example uses 2 sequences, so when n_parallel == 1, we need to enable the unified kv cache
        printf("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
        params.kv_unified = true;
    }
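
    // note: with a unified KV cache all sequences share a single buffer, so the
    // context can still host the second sequence (seq 1) used in the last run
    // even though only a single parallel slot was requested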

    common_init();

    if (params.n_predict < 0) {
        params.n_predict = 16;
    }

    auto n_past = 0;

    std::string result0;
    std::string result1;
    std::string result2;

    // init
    common_init_result llama_init = common_init_from_params(params);

    llama_model * model = llama_init.model.get();
    llama_context * ctx = llama_init.context.get();

    if (model == nullptr || ctx == nullptr) {
        fprintf(stderr, "%s : failed to init\n", __func__);
        return 1;
    }

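    // a sampler chain that contains only the distribution sampler with a fixed
    // seed - sampling is fully deterministic, which is what allows the three
    // generations below to be compared byte-for-byte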
    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));

    // tokenize prompt
    auto tokens = common_tokenize(ctx, params.prompt, true);

    // prepare the batch
    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        common_batch_add(batch, tokens[i], i, {0}, false);
    }
    batch.logits[batch.n_tokens - 1] = true; // generate next token

    // evaluate prompt
    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s : failed to evaluate\n", __func__);
        llama_batch_free(batch);
        return 1;
    }
    n_past += batch.n_tokens;

    // save state (rng, logits, embedding and kv_cache) to file
    {
        // llama_state_get_size() reports the buffer size needed for the state;
        // llama_state_get_data() returns how many bytes were actually written
        std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());

        FILE * fp_write = fopen("dump_state.bin", "wb");
        if (fp_write == nullptr) {
            fprintf(stderr, "%s : failed to open dump_state.bin for writing\n", __func__);
            return 1;
        }
        fwrite(state_mem.data(), 1, written, fp_write);
        fclose(fp_write);

        fprintf(stderr, "%s : serialized state into %zu out of a maximum of %zu bytes\n", __func__, written, state_mem.size());
    }
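
    // the serialized blob contains everything needed to resume decoding in a
    // different context; only application-side bookkeeping - here n_past, the
    // number of positions decoded so far - has to be saved separately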

    // save state (last tokens)
    const auto n_past_saved = n_past;

    // first run
    printf("\nfirst run: %s", params.prompt.c_str());

    for (auto i = 0; i < params.n_predict; i++) {
        auto next_token     = llama_sampler_sample(smpl, ctx, -1);
        auto next_token_str = common_token_to_piece(ctx, next_token);

        printf("%s", next_token_str.c_str());
        result0 += next_token_str;

        common_batch_clear(batch);
        common_batch_add(batch, next_token, n_past, {0}, true);

        if (llama_decode(ctx, batch)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_batch_free(batch);
            return 1;
        }
        n_past += 1;
    }

    printf("\n\n");
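
    // second run: restore the saved state into a brand-new context and sampler,
    // then generate again - the output must match the first run exactly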

    // make new context
    llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params));

    llama_sampler * smpl2 = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed));

    printf("\nsecond run: %s", params.prompt.c_str());

    // load state (rng, logits, embedding and kv_cache) from file
    {
        std::vector<uint8_t> state_mem;

        FILE * fp_read = fopen("dump_state.bin", "rb");
        if (fp_read == nullptr) {
            fprintf(stderr, "%s : failed to open dump_state.bin for reading\n", __func__);
            return 1;
        }
        fseek(fp_read, 0, SEEK_END);
        state_mem.resize(ftell(fp_read));
        fseek(fp_read, 0, SEEK_SET);
        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
        fclose(fp_read);

        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
            fprintf(stderr, "\n%s : failed to read state\n", __func__);
            return 1;
        }

        fprintf(stderr, "%s : deserialized state from %zu out of a maximum of %zu bytes\n", __func__, read, state_mem.size());
    }

    // restore state (last tokens)
    n_past = n_past_saved;

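    // the restored state includes the logits of the last decoded batch, so
    // sampling can resume immediately without re-evaluating the prompt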
    // second run
    for (auto i = 0; i < params.n_predict; i++) {
        auto next_token     = llama_sampler_sample(smpl2, ctx2, -1);
        auto next_token_str = common_token_to_piece(ctx2, next_token);

        printf("%s", next_token_str.c_str());
        result1 += next_token_str;

        common_batch_clear(batch);
        common_batch_add(batch, next_token, n_past, {0}, true);

        if (llama_decode(ctx2, batch)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_batch_free(batch);
            return 1;
        }
        n_past += 1;
    }

    printf("\n\n");

    if (result0 != result1) {
        fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
        return 1;
    }
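
    // third run: exercise the sequence-level state API - restore the full state
    // into yet another fresh context, then move the decoded prompt from
    // sequence 0 to sequence 1 and continue generating there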

    // make new context
    llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));

    llama_sampler * smpl3 = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed));

    printf("\nsingle seq run: %s", params.prompt.c_str());

    // load state (rng, logits, embedding and kv_cache) from file
    {
        std::vector<uint8_t> state_mem;

        FILE * fp_read = fopen("dump_state.bin", "rb");
        if (fp_read == nullptr) {
            fprintf(stderr, "%s : failed to open dump_state.bin for reading\n", __func__);
            return 1;
        }
        fseek(fp_read, 0, SEEK_END);
        state_mem.resize(ftell(fp_read));
        fseek(fp_read, 0, SEEK_SET);
        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
        fclose(fp_read);

        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
            fprintf(stderr, "\n%s : failed to read state\n", __func__);
            return 1;
        }

        fprintf(stderr, "%s : deserialized state from %zu out of a maximum of %zu bytes\n", __func__, read, state_mem.size());
    }

    // restore state (last tokens)
    n_past = n_past_saved;

    // save seq 0 and load into seq 1
    {
        // save the kv cache data of seq 0
        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
        if (ncopy != seq_store.size()) {
            fprintf(stderr, "\n%s : seq copy data length %zu does not match expected length %zu\n", __func__, ncopy, seq_store.size());
            return 1;
        }
        fprintf(stderr, "%s : seq 0 copied, %zu bytes\n", __func__, ncopy);

        // erase the whole kv cache
        llama_memory_clear(llama_get_memory(ctx3), true);
        fprintf(stderr, "%s : kv cache cleared\n", __func__);

        // restore the kv data into seq 1
        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
        if (nset != seq_store.size()) {
            fprintf(stderr, "\n%s : seq set data length %zu does not match expected length %zu\n", __func__, nset, seq_store.size());
            return 1;
        }
        fprintf(stderr, "%s : seq 1 restored, %zu bytes\n", __func__, nset);
    }
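
    // generation now continues on sequence 1, with positions resuming from
    // n_past_saved - as if the prompt had been decoded into seq 1 from the start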

    // third run with seq 1 instead of 0
    for (auto i = 0; i < params.n_predict; i++) {
        auto next_token     = llama_sampler_sample(smpl3, ctx3, -1);
        auto next_token_str = common_token_to_piece(ctx3, next_token);

        printf("%s", next_token_str.c_str());
        result2 += next_token_str;

        common_batch_clear(batch);
        common_batch_add(batch, next_token, n_past, {1}, true);

        if (llama_decode(ctx3, batch)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_batch_free(batch);
            return 1;
        }
        n_past += 1;
    }

    printf("\n");

    llama_sampler_free(smpl);
    llama_sampler_free(smpl2);
    llama_sampler_free(smpl3);

    // ctx2 and ctx3 were created manually via llama_init_from_model, so they
    // must be freed manually as well (ctx is owned by llama_init)
    llama_free(ctx2);
    llama_free(ctx3);

    llama_batch_free(batch);

    if (result0 != result2) {
        fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
        return 1;
    }

    fprintf(stderr, "\n%s : success\n", __func__);

    return 0;
}