#include "llama.h"
#include "common.h"
#include "console.h"

#include <algorithm>
#include <atomic>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <map>
#include <vector>
#include <fstream>
#include <thread>

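// legacy hard-coded test table, kept for reference - the tests are now loaded
// from files on disk by read_tests() below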
//static const std::map<std::string, std::vector<llama_token>> & k_tests() {
//    static std::map<std::string, std::vector<llama_token>> _k_tests = {
//        { "" , { }, },
//        { " " , { 220, }, },
//        { "  " , { 256, }, },
//        { "   " , { 262, }, },
//        { "\t" , { 197, }, },
//        { "\n" , { 198, }, },
//        { "\n\n" , { 271, }, },
//        { "\n\n\n" , { 1432, }, },
//        { "\t\n" , { 1602, }, },
//        { "Hello world" , { 9906, 1917, }, },
//        { " Hello world" , { 22691, 1917, }, },
//        { "Hello World" , { 9906, 4435, }, },
//        { " Hello World" , { 22691, 4435, }, },
//        { " Hello World!" , { 22691, 4435, 0, }, },
//        { "Hello, world!" , { 9906, 11, 1917, 0, }, },
//        { " Hello, world!" , { 22691, 11, 1917, 0, }, },
//        { " this is πŸ¦™.cpp" , { 420, 374, 11410, 99, 247, 13, 11055, }, },
//        { "w048 7tuijk dsdfhu" , { 86, 23904, 220, 22, 83, 2005, 42908, 11729, 3013, 17156, }, },
//        { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ" , { 79862, 102118, 13373, 64571, 34694, 3114, 112203, 80112, }, },
//        { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰" , { 21549, 222, 98629, 241, 45358, 233, 21549, 237, 45358, 224, 21549, 244, 21549, 115, 21549, 253, 45358, 223, 21549, 253, 21549, 95, 98629, 227, 21549, 223, 21549, 249, 21549, 227, 45358, 223, 21549, 231, }, },
//        { "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)", { 9468, 248, 222, 320, 8416, 8, 27623, 114, 102470, 9468, 234, 104, 31643, 320, 36773, 100166, 98634, 8, 26602, 227, 320, 3323, 43465, 430, 706, 1202, 1866, 4037, 8, }, },
//        { "Hello" , { 9906, }, },
//        { " Hello" , { 22691, }, },
//        { "  Hello" , { 220, 22691, }, },
//        { "   Hello" , { 256, 22691, }, },
//        { "    Hello" , { 262, 22691, }, },
//        { "    Hello\n    Hello" , { 262, 22691, 198, 262, 22691, }, },
//        { " (" , { 320, }, },
//        { "\n =" , { 198, 284, }, },
//        { "' era" , { 6, 11639, }, },
//        { "Hello, y'all! How are you 😁 ?ζˆ‘ζƒ³εœ¨appleε·₯作1314151倩~", { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949, 37046, 101067, 19000, 23182, 102301, 9263, 18136, 16, 36827, 21909, }, },
//        { "3" , { 18, }, },
//        { "33" , { 1644, }, },
//        { "333" , { 8765, }, },
//        { "3333" , { 8765, 18, }, },
//        { "33333" , { 8765, 1644, }, },
//        { "333333" , { 8765, 8765, }, },
//        { "3333333" , { 8765, 8765, 18, }, },
//        { "33333333" , { 8765, 8765, 1644, }, },
//        { "333333333" , { 8765, 8765, 8765, }, },
//    };
//
//    return _k_tests;
//}

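// test cases are stored in a pair of files next to the vocab:
//   <vocab>.inp - input strings, separated by the "__ggml_vocab_test__" marker
//   <vocab>.out - one line per input with the expected space-separated token ids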
using llama_tests = std::map<std::string, std::vector<llama_token>>;

static llama_tests read_tests(const std::string & fname_inp, const std::string & fname_out) {
    llama_tests tests;

    std::ifstream ifs_inp(fname_inp);
    if (!ifs_inp) {
        fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_inp.c_str());
        return tests;
    }

    // read the entire input file
    std::string sraw((std::istreambuf_iterator<char>(ifs_inp)), std::istreambuf_iterator<char>());

    std::ifstream ifs_out(fname_out);
    if (!ifs_out) {
        fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
        return tests;
    }

    // one line of expected token ids per test
    std::vector<std::string> sout;
    for (std::string line; std::getline(ifs_out, line);) {
        sout.push_back(line);
    }

    const std::string sep = "\n__ggml_vocab_test__\n";

    // split the raw input on the separator marker
    std::vector<std::string> sinp;

    size_t pos = 0;
    while (pos < sraw.size()) {
        const size_t next = sraw.find(sep, pos);
        if (next == std::string::npos) {
            sinp.push_back(sraw.substr(pos));
            break;
        }
        sinp.push_back(sraw.substr(pos, next - pos));
        pos = next + sep.size();
    }

    if (sinp.size() != sout.size()) {
        fprintf(stderr, "%s : error: input and output files have different number of tests\n", __func__);
        return tests;
    }

    for (size_t i = 0; i < sinp.size(); ++i) {
        const std::string & s = sinp[i];
        const std::string & o = string_strip(sout[i]);

        // parse the space-separated token ids
        std::vector<llama_token> toks;

        size_t pos = 0;
        while (pos < o.size()) {
            size_t next = o.find(' ', pos);
            if (next == std::string::npos) {
                next = o.size();
            }
            const std::string stok = o.substr(pos, next - pos);
            toks.push_back(std::stoi(stok));
            pos = next + 1;
        }

        tests[s] = toks;
    }

    return tests;
}

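// usage:
//
//   <program> vocab-file [text-file]
//
// without a text file: run the tests from <vocab-file>.inp / <vocab-file>.out
// with a text file: tokenize its contents and write the token ids to <text-file>.tokcpp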
int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];

    const std::string fname_inp = fname + ".inp";
    const std::string fname_out = fname + ".out";

    std::string fname_text;
    if (argc > 2) {
        fname_text = argv[2];
    }

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    llama_model * model;
    llama_context * ctx;

    llama_backend_init();

    // load the vocab
    {
        auto mparams = llama_model_default_params();

        mparams.vocab_only = true;

        model = llama_model_load_from_file(fname.c_str(), mparams);

        if (model == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            return 1;
        }

        auto cparams = llama_context_default_params();

        ctx = llama_init_from_model(model, cparams);

        if (ctx == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            llama_model_free(model);
            return 1;
        }
    }

#ifdef _WIN32
    // We need this for unicode console support
    console::init(false, false);
    atexit([]() { console::cleanup(); });
#endif

    // written concurrently by the tokenization threads
    std::atomic<bool> success{true};

    const auto k_tests = [&]() -> llama_tests {
        // when a text file is provided, skip the built-in tests
        if (!fname_text.empty()) {
            return {};
        }

        const auto res = read_tests(fname_inp, fname_out);

        if (res.empty()) {
            fprintf(stderr, "%s : error: no tests found\n", __func__);
            exit(1);
        }

        return res;
    }();

    // the expected tokenizations do not include special tokens (e.g. BOS),
    // so do not add them here
    const bool add_special = false;
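
    // all threads run the identical set of tests; running them concurrently also
    // acts as a basic thread-safety check of the tokenizer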

    // multi-threaded tokenization
    const int nthread = std::max(1u, std::thread::hardware_concurrency()); // can report 0
    std::vector<std::thread> threads(nthread);

    for (int i = 0; i < nthread; i++) {
        threads[i] = std::thread([&, i]() {
            for (const auto & test_kv : k_tests) {
                const std::vector<llama_token> res = common_tokenize(ctx, test_kv.first, add_special, false);

                // here only print the result of the first thread
                // because the other threads are running the same tests
                if (i != 0) {
                    continue;
                }

                printf("\n");
                printf("src: '%s'\n", test_kv.first.c_str());
                printf("res: '%s'\n", common_detokenize(ctx, res).c_str());
                printf("tok: ");
                for (const auto & tok : res) {
                    printf("%d ", tok);
                }
                printf("\n");

                bool correct = res.size() == test_kv.second.size();
                for (int j = 0; j < (int) res.size() && correct; ++j) {
                    if (test_kv.second[j] != res[j]) {
                        correct = false;
                    }
                }

                if (!correct) {
                    fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
                    fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
                        common_detokenize(ctx, res).c_str(),
                        common_detokenize(ctx, test_kv.second).c_str());
                    fprintf(stderr, "%s : expected tokens: ", __func__);
                    for (const auto & t : test_kv.second) {
                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
                    }
                    fprintf(stderr, "\n");
                    fprintf(stderr, "%s : got tokens:      ", __func__);
                    for (const auto & t : res) {
                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
                    }
                    fprintf(stderr, "\n");

                    success = false;
                }
            }
        });
    }

    for (int i = 0; i < nthread; i++) {
        threads[i].join();
    }

    // single threaded tokenization
    if (!fname_text.empty()) {
        fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());

        std::string text;
        {
            std::ifstream ifs(fname_text);
            if (!ifs) {
                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
                return 1;
            }
            text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
        }

        fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());

        std::vector<llama_token> res;

        {
            const auto t_start = ggml_time_us();

            res = common_tokenize(ctx, text, add_special, false);

            const auto t_end = ggml_time_us();

            fprintf(stderr, "%s : tokenized in %.3f ms (cpp)\n", __func__, (t_end - t_start) / 1000.0);
        }

        fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());

        // write one token id per line
        const std::string fname_res = fname_text + ".tokcpp";

        {
            std::ofstream ofs(fname_res);
            if (!ofs) {
                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_res.c_str());
                return 1;
            }

            for (const auto & tok : res) {
                //ofs << tok << " '" << string_strip(llama_detokenize(ctx, std::vector<int>{tok})) << "'" << std::endl;
                ofs << tok << "\n";
            }
        }

        fprintf(stderr, "%s : tokens written to '%s'\n", __func__, fname_res.c_str());
    }

    // free the context before the model it references
    llama_free(ctx);
    llama_model_free(model);

    llama_backend_free();

    printf("\n");
    printf("Tests %s\n", success ? "passed" : "failed");

    // 3 so test failures are distinguishable from the setup errors above, which return 1
    return success ? 0 : 3;
}