#include "llama.h"
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}

int main(int argc, char ** argv) {
    std::string model_path;
    int ngl = 99;      // number of model layers to offload to the GPU
    int n_ctx = 2048;  // context size in tokens

    // parse command line arguments
    for (int i = 1; i < argc; i++) {
        try {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-c") == 0) {
                if (i + 1 < argc) {
                    n_ctx = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    ngl = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                print_usage(argc, argv);
                return 1;
            }
        } catch (std::exception & e) {
            fprintf(stderr, "error: %s\n", e.what());
            print_usage(argc, argv);
            return 1;
        }
    }
    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }

    // only print errors
    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
        if (level >= GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "%s", text);
        }
    }, nullptr);

    // load dynamic backends
    ggml_backend_load_all();

    // initialize the model
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
    if (!model) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // initialize the context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
    // allow a full context's worth of tokens in a single llama_decode() call
    ctx_params.n_batch = n_ctx;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    // initialize the sampler chain: min-p filtering, temperature scaling,
    // then sampling from the resulting probability distribution
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // helper function to evaluate a prompt and generate a response
    auto generate = [&](const std::string & prompt) {
        std::string response;

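        // the memory (KV cache) for sequence 0 is empty on the first turn:
        // llama_memory_seq_pos_max() returns -1 in that case, so special
        // tokens (e.g. BOS) are only added when tokenizing the first prompt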
        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;

        // tokenize the prompt
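        // the first call passes a NULL token buffer, so llama_tokenize() returns
        // the negative of the number of tokens required; the second call fills the buffer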
        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
            GGML_ABORT("failed to tokenize the prompt\n");
        }

        // prepare a batch for the prompt
        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
        llama_token new_token_id;
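        // generation loop: decode the pending batch, sample the next token,
        // and stop when the model emits an end-of-generation token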
        while (true) {
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
            int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) + 1;
            if (n_ctx_used + batch.n_tokens > n_ctx) {
                printf("\033[0m\n"); // reset the terminal color
                fprintf(stderr, "context size exceeded\n");
                exit(0);
            }

            int ret = llama_decode(ctx, batch);
            if (ret != 0) {
                GGML_ABORT("failed to decode, ret = %d\n", ret);
            }

            // sample the next token (idx == -1 uses the logits of the last token in the batch)
            new_token_id = llama_sampler_sample(smpl, ctx, -1);

            // is it an end of generation?
            if (llama_vocab_is_eog(vocab, new_token_id)) {
                break;
            }

            // convert the token to a string, print it and add it to the response
            char buf[256];
            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
            if (n < 0) {
                GGML_ABORT("failed to convert token to piece\n");
            }
            std::string piece(buf, n);
            printf("%s", piece.c_str());
            fflush(stdout);
            response += piece;

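            // the KV cache already holds the prompt and all previously generated
            // tokens, so only the newly sampled token needs to be decoded next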
            // prepare the next batch with the sampled token
            batch = llama_batch_get_one(&new_token_id, 1);
        }

        return response;
    };

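    // the chat history is kept in `messages`; `formatted` holds the history
    // rendered with the model's chat template, and `prev_len` tracks how much
    // of it has already been evaluated, so each turn only sends the new part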
    std::vector<llama_chat_message> messages;
    std::vector<char> formatted(llama_n_ctx(ctx));
    int prev_len = 0;
    while (true) {
        // get user input
        printf("\033[32m> \033[0m");
        std::string user;
        std::getline(std::cin, user);

        if (user.empty()) {
            break;
        }

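        // retrieve the chat template embedded in the model's metadata
        // (passing nullptr for the name selects the default template)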
        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);

        // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        if (new_len > (int)formatted.size()) {
            formatted.resize(new_len);
            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        }
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        // use only the part of the formatted conversation added since the previous
        // turn (the new user message plus the assistant prefix) as the prompt
        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);

        // generate a response
        printf("\033[33m");
        std::string response = generate(prompt);
        printf("\n\033[0m");

        // add the response to the messages
        messages.push_back({"assistant", strdup(response.c_str())});
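        // re-apply the template without the assistant prefix to record how much of
        // the formatted history has already been seen; the next prompt starts here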
        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    // free resources
    for (auto & msg : messages) {
        free(const_cast<char *>(msg.content));
    }
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_model_free(model);

    return 0;
}