1#include "arg.h"
2#include "common.h"
3
4#include <string>
5#include <vector>
6#include <sstream>
7#include <unordered_set>
8
9#undef NDEBUG
10#include <cassert>
11
12int main(void) {
13 common_params params;
14
15 printf(format: "test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
16 for (int ex = 0; ex < LLAMA_EXAMPLE_COUNT; ex++) {
17 try {
18 auto ctx_arg = common_params_parser_init(params, ex: (enum llama_example)ex);
19 std::unordered_set<std::string> seen_args;
20 std::unordered_set<std::string> seen_env_vars;
21 for (const auto & opt : ctx_arg.options) {
22 // check for args duplications
23 for (const auto & arg : opt.args) {
24 if (seen_args.find(x: arg) == seen_args.end()) {
25 seen_args.insert(x: arg);
26 } else {
27 fprintf(stderr, format: "test-arg-parser: found different handlers for the same argument: %s", arg);
28 exit(status: 1);
29 }
30 }
31 // check for env var duplications
32 if (opt.env) {
33 if (seen_env_vars.find(x: opt.env) == seen_env_vars.end()) {
34 seen_env_vars.insert(x: opt.env);
35 } else {
36 fprintf(stderr, format: "test-arg-parser: found different handlers for the same env var: %s", opt.env);
37 exit(status: 1);
38 }
39 }
40 }
41 } catch (std::exception & e) {
42 printf(format: "%s\n", e.what());
43 assert(false);
44 }
45 }
46
47 auto list_str_to_char = [](std::vector<std::string> & argv) -> std::vector<char *> {
48 std::vector<char *> res;
49 for (auto & arg : argv) {
50 res.push_back(x: const_cast<char *>(arg.data()));
51 }
52 return res;
53 };
54
55 std::vector<std::string> argv;
56
57 printf(format: "test-arg-parser: test invalid usage\n\n");
58
59 // missing value
60 argv = {"binary_name", "-m"};
61 assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
62
63 // wrong value (int)
64 argv = {"binary_name", "-ngl", "hello"};
65 assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
66
67 // wrong value (enum)
68 argv = {"binary_name", "-sm", "hello"};
69 assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
70
71 // non-existence arg in specific example (--draft cannot be used outside llama-speculative)
72 argv = {"binary_name", "--draft", "123"};
73 assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
74
75
76 printf(format: "test-arg-parser: test valid usage\n\n");
77
78 argv = {"binary_name", "-m", "model_file.gguf"};
79 assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
80 assert(params.model.path == "model_file.gguf");
81
82 argv = {"binary_name", "-t", "1234"};
83 assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
84 assert(params.cpuparams.n_threads == 1234);
85
86 argv = {"binary_name", "--verbose"};
87 assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
88 assert(params.verbosity > 1);
89
90 argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
91 assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
92 assert(params.model.path == "abc.gguf");
93 assert(params.n_predict == 6789);
94 assert(params.n_batch == 9090);
95
96 // --draft cannot be used outside llama-speculative
97 argv = {"binary_name", "--draft", "123"};
98 assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
99 assert(params.speculative.n_max == 123);
100
101// skip this part on windows, because setenv is not supported
102#ifdef _WIN32
103 printf("test-arg-parser: skip on windows build\n");
104#else
105 printf(format: "test-arg-parser: test environment variables (valid + invalid usages)\n\n");
106
107 setenv(name: "LLAMA_ARG_THREADS", value: "blah", replace: true);
108 argv = {"binary_name"};
109 assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
110
111 setenv(name: "LLAMA_ARG_MODEL", value: "blah.gguf", replace: true);
112 setenv(name: "LLAMA_ARG_THREADS", value: "1010", replace: true);
113 argv = {"binary_name"};
114 assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
115 assert(params.model.path == "blah.gguf");
116 assert(params.cpuparams.n_threads == 1010);
117
118
119 printf(format: "test-arg-parser: test environment variables being overwritten\n\n");
120
121 setenv(name: "LLAMA_ARG_MODEL", value: "blah.gguf", replace: true);
122 setenv(name: "LLAMA_ARG_THREADS", value: "1010", replace: true);
123 argv = {"binary_name", "-m", "overwritten.gguf"};
124 assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
125 assert(params.model.path == "overwritten.gguf");
126 assert(params.cpuparams.n_threads == 1010);
127#endif // _WIN32
128
129 printf(format: "test-arg-parser: test curl-related functions\n\n");
130 const char * GOOD_URL = "http://ggml.ai/";
131 const char * BAD_URL = "http://ggml.ai/404";
132
133 {
134 printf(format: "test-arg-parser: test good URL\n\n");
135 auto res = common_remote_get_content(url: GOOD_URL, params: {});
136 assert(res.first == 200);
137 assert(res.second.size() > 0);
138 std::string str(res.second.data(), res.second.size());
139 assert(str.find("llama.cpp") != std::string::npos);
140 }
141
142 {
143 printf(format: "test-arg-parser: test bad URL\n\n");
144 auto res = common_remote_get_content(url: BAD_URL, params: {});
145 assert(res.first == 404);
146 }
147
148 {
149 printf(format: "test-arg-parser: test max size error\n");
150 common_remote_params params;
151 params.max_size = 1;
152 try {
153 common_remote_get_content(url: GOOD_URL, params);
154 assert(false && "it should throw an error");
155 } catch (std::exception & e) {
156 printf(format: " expected error: %s\n\n", e.what());
157 }
158 }
159
160 printf(format: "test-arg-parser: all tests OK\n\n");
161}
162