#include "arg.h"
#include "common.h"

#include <cstdio>
#include <cstdlib>
#include <exception>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#undef NDEBUG
#include <cassert>

| 12 | int main(void) { |
| 13 | common_params params; |
| 14 | |
| 15 | printf(format: "test-arg-parser: make sure there is no duplicated arguments in any examples\n\n" ); |
| 16 | for (int ex = 0; ex < LLAMA_EXAMPLE_COUNT; ex++) { |
| 17 | try { |
| 18 | auto ctx_arg = common_params_parser_init(params, ex: (enum llama_example)ex); |
| 19 | std::unordered_set<std::string> seen_args; |
| 20 | std::unordered_set<std::string> seen_env_vars; |
| 21 | for (const auto & opt : ctx_arg.options) { |
| 22 | // check for args duplications |
| 23 | for (const auto & arg : opt.args) { |
| 24 | if (seen_args.find(x: arg) == seen_args.end()) { |
| 25 | seen_args.insert(x: arg); |
| 26 | } else { |
| 27 | fprintf(stderr, format: "test-arg-parser: found different handlers for the same argument: %s" , arg); |
| 28 | exit(status: 1); |
| 29 | } |
| 30 | } |
| 31 | // check for env var duplications |
| 32 | if (opt.env) { |
| 33 | if (seen_env_vars.find(x: opt.env) == seen_env_vars.end()) { |
| 34 | seen_env_vars.insert(x: opt.env); |
| 35 | } else { |
| 36 | fprintf(stderr, format: "test-arg-parser: found different handlers for the same env var: %s" , opt.env); |
| 37 | exit(status: 1); |
| 38 | } |
| 39 | } |
| 40 | } |
| 41 | } catch (std::exception & e) { |
| 42 | printf(format: "%s\n" , e.what()); |
| 43 | assert(false); |
| 44 | } |
| 45 | } |
| 46 | |
| 47 | auto list_str_to_char = [](std::vector<std::string> & argv) -> std::vector<char *> { |
| 48 | std::vector<char *> res; |
| 49 | for (auto & arg : argv) { |
| 50 | res.push_back(x: const_cast<char *>(arg.data())); |
| 51 | } |
| 52 | return res; |
| 53 | }; |
| 54 | |
| 55 | std::vector<std::string> argv; |
| 56 | |
| 57 | printf(format: "test-arg-parser: test invalid usage\n\n" ); |
| 58 | |
| 59 | // missing value |
| 60 | argv = {"binary_name" , "-m" }; |
| 61 | assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 62 | |
| 63 | // wrong value (int) |
| 64 | argv = {"binary_name" , "-ngl" , "hello" }; |
| 65 | assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 66 | |
| 67 | // wrong value (enum) |
| 68 | argv = {"binary_name" , "-sm" , "hello" }; |
| 69 | assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 70 | |
| 71 | // non-existence arg in specific example (--draft cannot be used outside llama-speculative) |
| 72 | argv = {"binary_name" , "--draft" , "123" }; |
| 73 | assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING)); |
| 74 | |
| 75 | |
| 76 | printf(format: "test-arg-parser: test valid usage\n\n" ); |
| 77 | |
| 78 | argv = {"binary_name" , "-m" , "model_file.gguf" }; |
| 79 | assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 80 | assert(params.model.path == "model_file.gguf" ); |
| 81 | |
| 82 | argv = {"binary_name" , "-t" , "1234" }; |
| 83 | assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 84 | assert(params.cpuparams.n_threads == 1234); |
| 85 | |
| 86 | argv = {"binary_name" , "--verbose" }; |
| 87 | assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 88 | assert(params.verbosity > 1); |
| 89 | |
| 90 | argv = {"binary_name" , "-m" , "abc.gguf" , "--predict" , "6789" , "--batch-size" , "9090" }; |
| 91 | assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 92 | assert(params.model.path == "abc.gguf" ); |
| 93 | assert(params.n_predict == 6789); |
| 94 | assert(params.n_batch == 9090); |
| 95 | |
| 96 | // --draft cannot be used outside llama-speculative |
| 97 | argv = {"binary_name" , "--draft" , "123" }; |
| 98 | assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE)); |
| 99 | assert(params.speculative.n_max == 123); |
| 100 | |
| 101 | // skip this part on windows, because setenv is not supported |
| 102 | #ifdef _WIN32 |
| 103 | printf("test-arg-parser: skip on windows build\n" ); |
| 104 | #else |
| 105 | printf(format: "test-arg-parser: test environment variables (valid + invalid usages)\n\n" ); |
| 106 | |
| 107 | setenv(name: "LLAMA_ARG_THREADS" , value: "blah" , replace: true); |
| 108 | argv = {"binary_name" }; |
| 109 | assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 110 | |
| 111 | setenv(name: "LLAMA_ARG_MODEL" , value: "blah.gguf" , replace: true); |
| 112 | setenv(name: "LLAMA_ARG_THREADS" , value: "1010" , replace: true); |
| 113 | argv = {"binary_name" }; |
| 114 | assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 115 | assert(params.model.path == "blah.gguf" ); |
| 116 | assert(params.cpuparams.n_threads == 1010); |
| 117 | |
| 118 | |
| 119 | printf(format: "test-arg-parser: test environment variables being overwritten\n\n" ); |
| 120 | |
| 121 | setenv(name: "LLAMA_ARG_MODEL" , value: "blah.gguf" , replace: true); |
| 122 | setenv(name: "LLAMA_ARG_THREADS" , value: "1010" , replace: true); |
| 123 | argv = {"binary_name" , "-m" , "overwritten.gguf" }; |
| 124 | assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); |
| 125 | assert(params.model.path == "overwritten.gguf" ); |
| 126 | assert(params.cpuparams.n_threads == 1010); |
| 127 | #endif // _WIN32 |
| 128 | |
| 129 | printf(format: "test-arg-parser: test curl-related functions\n\n" ); |
| 130 | const char * GOOD_URL = "http://ggml.ai/" ; |
| 131 | const char * BAD_URL = "http://ggml.ai/404" ; |
| 132 | |
| 133 | { |
| 134 | printf(format: "test-arg-parser: test good URL\n\n" ); |
| 135 | auto res = common_remote_get_content(url: GOOD_URL, params: {}); |
| 136 | assert(res.first == 200); |
| 137 | assert(res.second.size() > 0); |
| 138 | std::string str(res.second.data(), res.second.size()); |
| 139 | assert(str.find("llama.cpp" ) != std::string::npos); |
| 140 | } |
| 141 | |
| 142 | { |
| 143 | printf(format: "test-arg-parser: test bad URL\n\n" ); |
| 144 | auto res = common_remote_get_content(url: BAD_URL, params: {}); |
| 145 | assert(res.first == 404); |
| 146 | } |
| 147 | |
| 148 | { |
| 149 | printf(format: "test-arg-parser: test max size error\n" ); |
| 150 | common_remote_params params; |
| 151 | params.max_size = 1; |
| 152 | try { |
| 153 | common_remote_get_content(url: GOOD_URL, params); |
| 154 | assert(false && "it should throw an error" ); |
| 155 | } catch (std::exception & e) { |
| 156 | printf(format: " expected error: %s\n\n" , e.what()); |
| 157 | } |
| 158 | } |
| 159 | |
| 160 | printf(format: "test-arg-parser: all tests OK\n\n" ); |
| 161 | } |
| 162 | |