1#include "common.h"
2//#include "log.h" // TODO: start using log.h
3#include "llama.h"
4
5#include <cstdio>
6#include <cstring>
7#include <fstream>
8#include <string>
9#include <vector>
10#include <iostream> // TODO: remove me
11
12#if defined(_WIN32)
13#define WIN32_LEAN_AND_MEAN
14#include <windows.h>
15#include <shellapi.h> // For CommandLineToArgvW
16#endif
17
// Prints the command-line help text of the tokenize program to standard output.
//
// argv0 is the program name shown on the "usage:" line; it is formatted with
// %s, so it must point to a valid NUL-terminated string.
static void print_usage_information(const char * argv0) {
    printf("usage: %s [options]\n\n", argv0);
    printf("The tokenize program tokenizes a prompt using a given model,\n");
    printf("and prints the resulting tokens to standard output.\n\n");
    printf("It needs a model file, a prompt, and optionally other flags\n");
    printf("to control the behavior of the tokenizer.\n\n");
    printf(" The possible options are:\n");
    printf("\n");
    printf(" -h, --help print this help and exit\n");
    printf(" -m MODEL_PATH, --model MODEL_PATH path to model.\n");
    printf(" --ids if given, only print numerical token IDs, and not token strings.\n");
    printf(" The output format looks like [1, 2, 3], i.e. parseable by Python.\n");
    printf(" -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n");
    printf(" -p PROMPT, --prompt PROMPT read prompt from the argument.\n");
    printf(" --stdin read prompt from standard input.\n");
    printf(" --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
    // the doubled backslashes print a literal "\n" / "\t" in the help text
    printf(" --no-escape do not escape input (such as \\n, \\t, etc.).\n");
    printf(" --no-parse-special do not parse control tokens.\n");
    printf(" --log-disable disable logs. Makes stderr quiet when loading the model.\n");
    printf(" --show-count print the total number of tokens.\n");
}
39
40static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) {
41 (void) level;
42 (void) text;
43 (void) user_data;
44}
45
// Reads the whole file at `filepath` (opened in binary mode) and returns its
// contents.
//
// `success` is set to true only when the file could be opened and read to the
// end without an I/O error; an empty file is a success and yields an empty
// string. On failure a message is written to stderr, `success` is false and an
// empty string is returned.
static std::string read_prompt_from_file(const char * filepath, bool & success) {
    success = false;

    std::ifstream in(filepath, std::ios::binary);
    if (!in) {
        fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno));
        return std::string();
    }

    // do not assume the file is seekable (e.g. /dev/stdin): read it in
    // fixed-size chunks until the stream stops delivering data.
    //
    // Reading through in.read() (rather than streaming in.rdbuf() into another
    // stream) keeps the error state on `in` itself, so a real I/O error can be
    // told apart from a plain end-of-file below.
    std::string contents;
    char chunk[4096];
    while (in.read(chunk, sizeof(chunk)) || in.gcount() > 0) {
        contents.append(chunk, static_cast<size_t>(in.gcount()));
    }

    // hitting end-of-file sets eofbit/failbit, which is expected here;
    // only badbit signals that the read itself went wrong
    if (in.bad()) {
        fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno));
        return std::string();
    }

    success = true;
    return contents;
}
65
66//
67// Function: ingest_args(...) -> vector<string>
68//
69// Takes argc and argv arguments, and converts them to a vector of UTF-8 encoded
70// strings, as an STL vector<string>.
71//
72// In particular, it handles character encoding shenanigans on Windows.
73//
74// Note: raw_argc and raw_argv are not actually read at all on Windows.
75// On Windows we call GetCommandLineW to get the arguments in wchar_t
76// format, ignoring the regular argc/argv arguments to main().
77//
78// TODO: potential opportunity to roll common stuff into common/console.cpp
79// in relation to Windows wchar_t shenanigans.
80static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) {
81 std::vector<std::string> argv;
82
83 // Handle Windows, if given non-ASCII arguments.
84 // We convert wchar_t arguments into UTF-8 char* on this platform.
85 // Lets you invoke 'tokenize' on Windows cmd.exe with non-ASCII characters
86 // without throwing tantrums.
87#if defined(_WIN32)
88 int argc;
89 const LPWSTR cmdline_wargv = GetCommandLineW();
90 LPWSTR * wargv = CommandLineToArgvW(cmdline_wargv, &argc);
91
92 // silence unused arg warnings
93 (void) raw_argc;
94 (void) raw_argv;
95
96 for (int i = 0; i < argc; ++i) {
97 int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), 0, 0, NULL, NULL);
98 char * output_buf = (char *) calloc(length_needed+1, sizeof(char));
99 GGML_ASSERT(output_buf);
100
101 WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), output_buf, length_needed, NULL, NULL);
102 output_buf[length_needed] = '\0';
103
104 argv.push_back(output_buf);
105 free(output_buf);
106 }
107
108 LocalFree((HLOCAL) wargv);
109#else
110 int argc = raw_argc;
111 for (int i = 0; i < argc; ++i) {
112 argv.push_back(x: raw_argv[i]);
113 }
114#endif
115
116 GGML_ASSERT((unsigned int) argc == argv.size());
117
118 return argv;
119}
120
//
// Function: write_utf8_cstr_to_stdout(const char *, bool &) -> <writes to stdout>
//
// Writes a string to standard output, taking into account that on Windows
// special handling is needed for it to display correctly. Works even if the
// user has not set a unicode code page on a Windows cmd.exe.
//
// invalid_utf8 is always reset to false on entry. In case of invalid UTF-8,
// invalid_utf8 is set to true on Windows (when writing to a console), and a
// human-readable hex dump of the bytes, such as "<ff fe>", is written instead.
//
// On non-Windows systems, simply printfs() the string.
static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
    invalid_utf8 = false;

#if defined(_WIN32)
    // Are we in a console?
    HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD dwMode = 0;

    // According to Microsoft docs:
    // "WriteConsole fails if it is used with a standard handle that is redirected to a file."
    // Also according to the docs, you can use GetConsoleMode to check for that.
    if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
        printf("%s", str);
        return;
    }

    // MultiByteToWideChar reports an error if str is empty, don't report
    // them as invalid_utf8.
    if (*str == 0) {
        return;
    }

    const int str_len = (int) strlen(str);

    // first call: validate the input and ask how many wide chars are needed
    const int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, str_len, NULL, 0);
    if (length_needed == 0) {
        const DWORD err = GetLastError();
        if (err == ERROR_NO_UNICODE_TRANSLATION) {
            invalid_utf8 = true;
            // dump the raw bytes as space-separated hex inside angle brackets
            printf("<");
            for (int i = 0; i < str_len; ++i) {
                if (i > 0) {
                    printf(" ");
                }
                printf("%02x", (uint8_t) str[i]);
            }
            printf(">");
            return;
        }
        GGML_ABORT("MultiByteToWideChar() failed in an unexpected way.");
    }

    std::wstring wstr(length_needed, L'\0');
    const int converted = MultiByteToWideChar(CP_UTF8, 0, str, str_len, &wstr[0], length_needed);

    // WriteConsoleW does not go through stdio: push out anything earlier
    // printf() calls may still hold in the stdout buffer so the text keeps
    // its order on screen.
    fflush(stdout);
    if (converted > 0) {
        WriteConsoleW(hConsole, wstr.data(), converted, NULL, NULL);
    }
#else
    // TODO: reporting invalid_utf8 would be useful on non-Windows too.
    // printf will silently just write bad unicode.
    printf("%s", str);
#endif
}
185
186int main(int raw_argc, char ** raw_argv) {
187 const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv);
188 const int argc = argv.size();
189
190 if (argc <= 1) {
191 print_usage_information(argv0: argv[0].c_str());
192 return 1;
193 }
194
195 //////
196 // Read out all the command line arguments.
197 //////
198
199 // variables where to put any arguments we see.
200 bool printing_ids = false;
201 bool no_bos = false;
202 bool no_escape = false;
203 bool no_parse_special = false;
204 bool disable_logging = false;
205 bool show_token_count = false;
206 const char * model_path = NULL;
207 const char * prompt_path = NULL;
208 const char * prompt_arg = NULL;
209
210 // track which arguments were explicitly given
211 // used for sanity checking down the line
212 bool model_path_set = false;
213 bool prompt_path_set = false;
214 bool prompt_set = false;
215 bool stdin_set = false;
216
217 int iarg = 1;
218 for (; iarg < argc; ++iarg) {
219 std::string arg{argv[iarg]};
220 if (arg == "-h" || arg == "--help") {
221 print_usage_information(argv0: argv[0].c_str());
222 return 0;
223 }
224 else if (arg == "--ids") {
225 printing_ids = true;
226 }
227 else if (arg == "-m" || arg == "--model") {
228 if (model_path_set) {
229 fprintf(stderr, format: "Error: -m or --model specified multiple times.\n");
230 return 1;
231 }
232 model_path = argv[++iarg].c_str();
233 model_path_set = true;
234 }
235 else if (arg == "--no-bos") {
236 no_bos = true;
237 }
238 else if (arg == "--no-escape") {
239 no_escape = true;
240 }
241 else if (arg == "--no-parse-special") {
242 no_parse_special = true;
243 }
244 else if (arg == "-p" || arg == "--prompt") {
245 if (prompt_set) {
246 fprintf(stderr, format: "Error: -p or --prompt specified multiple times.\n");
247 return 1;
248 }
249 prompt_arg = argv[++iarg].c_str();
250 prompt_set = true;
251 }
252 else if (arg == "-f" || arg == "--file") {
253 if (prompt_path_set) {
254 fprintf(stderr, format: "Error: -f or --file specified multiple times.\n");
255 return 1;
256 }
257 prompt_path = argv[++iarg].c_str();
258 prompt_path_set = true;
259 }
260 else if (arg == "--stdin") {
261 stdin_set = true;
262 }
263 else if (arg == "--log-disable") {
264 disable_logging = true;
265 }
266 else if (arg == "--show-count") {
267 show_token_count = true;
268 }
269 else {
270 fprintf(stderr, format: "Error: unknown option '%s'\n", argv[iarg].c_str());
271 return 1;
272 }
273 }
274
275 //////
276 // Sanity check the command line arguments.
277 //////
278
279 // Check that we have the required stuff set.
280 if (model_path_set && model_path == NULL) {
281 fprintf(stderr, format: "Error: --model requires an argument.\n");
282 return 1;
283 }
284 if (!model_path_set) {
285 fprintf(stderr, format: "Error: must specify --model.\n");
286 return 1;
287 }
288 if (prompt_path_set && prompt_path == NULL) {
289 fprintf(stderr, format: "Error: --file requires an argument.\n");
290 return 1;
291 }
292 if (prompt_set && prompt_arg == NULL) {
293 fprintf(stderr, format: "Error: --prompt requires an argument.\n");
294 return 1;
295 }
296 const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set);
297 if (prompts_set > 1) {
298 fprintf(stderr, format: "Error: --stdin, --file and --prompt are mutually exclusive.\n");
299 return 1;
300 }
301 // Must have some prompt.
302 if (prompts_set == 0) {
303 fprintf(stderr, format: "Error: must specify one of: --stdin, --file or --prompt.\n");
304 return 1;
305 }
306
307 GGML_ASSERT(model_path);
308 GGML_ASSERT(prompt_path || prompt_arg || stdin_set);
309
310 //////
311 // Figure out where will the prompt come from.
312 //////
313
314 std::string prompt;
315 if (prompt_path_set) {
316 bool success = false;
317 prompt = read_prompt_from_file(filepath: prompt_path, success);
318 if (!success) {
319 return 1;
320 }
321 } else if (prompt_set) {
322 prompt = prompt_arg;
323 } else {
324 GGML_ASSERT(stdin_set);
325 // we read stdin *after* loading model (early exit if model cannot
326 // be loaded, which can be a nicer user experience)
327 }
328
329 //////
330 // Start actually doing the tokenizing stuff.
331 //////
332
333 if (disable_logging) {
334 llama_log_set(log_callback: llama_log_callback_null, NULL);
335 }
336
337 llama_backend_init();
338
339 llama_model_params model_params = llama_model_default_params();
340 model_params.vocab_only = true;
341 llama_model * model = llama_model_load_from_file(path_model: model_path, params: model_params);
342 if (!model) {
343 fprintf(stderr, format: "Error: could not load model from file '%s'.\n", model_path);
344 return 1;
345 }
346
347 const llama_vocab * vocab = llama_model_get_vocab(model);
348
349 llama_context_params ctx_params = llama_context_default_params();
350 llama_context * ctx = llama_init_from_model(model, params: ctx_params);
351 if (!ctx) {
352 fprintf(stderr, format: "Error: could not create context.\n");
353 return 1;
354 }
355
356 // read entire prompt from stdin?
357 if (stdin_set) {
358 GGML_ASSERT(!prompt_path_set && !prompt_set);
359
360 std::stringstream stdin_buffer;
361 stdin_buffer << std::cin.rdbuf();
362 if (std::cin.fail()) {
363 fprintf(stderr, format: "Error: could not read the entire standard input.\n");
364 return 1;
365 }
366
367 prompt = stdin_buffer.str();
368 }
369
370 const bool model_wants_add_bos = llama_vocab_get_add_bos(vocab);
371 const bool add_bos = model_wants_add_bos && !no_bos;
372 const bool parse_special = !no_parse_special;
373 const bool escape = !no_escape;
374
375 if (escape) {
376 string_process_escapes(input&: prompt);
377 }
378
379 std::vector<llama_token> tokens;
380 tokens = common_tokenize(vocab, text: prompt, add_special: add_bos, parse_special);
381
382 if (printing_ids) {
383 printf(format: "[");
384 }
385
386 for (int i = 0; i < (int) tokens.size(); i++) {
387 if (printing_ids) {
388 if (i > 0) {
389 printf(format: ", ");
390 }
391 printf(format: "%d", tokens[i]);
392 } else {
393 bool invalid_utf8 = false;
394 printf(format: "%6d -> '", tokens[i]);
395 write_utf8_cstr_to_stdout(str: common_token_to_piece(ctx, token: tokens[i]).c_str(), invalid_utf8);
396 if (invalid_utf8) {
397 printf(format: "' (utf-8 decode failure)\n");
398 } else {
399 printf(format: "'\n");
400 }
401 }
402 }
403
404 if (printing_ids) {
405 printf(format: "]\n");
406 }
407
408 if (show_token_count) {
409 printf(format: "Total number of tokens: %zu\n", tokens.size());
410 }
411 // silence valgrind
412 llama_free(ctx);
413 llama_model_free(model);
414
415 return 0;
416}
417