// A basic application simulating a server with multiple clients.
// The clients submit requests to the server and they are processed in parallel.

#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <ctime>
#include <cctype>
#include <sstream>
#include <algorithm>
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
    size_t start = 0;
    size_t end = str.size();

    while (start < end && isspace(str[start])) {
        start += 1;
    }

    while (end > start && isspace(str[end - 1])) {
        end -= 1;
    }

    return str.substr(start, end - start);
}

static std::string k_system =
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User:
Recommend a nice restaurant in the area.
Assistant:
I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
User:
Who is Richard Feynman?
Assistant:
Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
)";

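// question/answer pairs used as "junk" filler to pad each client's prompt (controlled by n_junk below)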
static std::vector<std::string> k_questions = {
    "What is the tallest mountain in the world?",
    "Who was the first person to win two Nobel Prizes?",
    "Which country invented paper?",
    "What organ is primarily responsible for pumping blood throughout the body?",
    "Which planet is known for its prominent ring system?",
    "Who directed the movie 'Inception'?",
    "What is the freezing point of water in Fahrenheit?",
    "Which animal is known to have the longest lifespan?",
    "What language has the most native speakers worldwide?",
    "What is the capital city of Canada?",
    "Who is credited with inventing the World Wide Web?",
    "Which metal is liquid at room temperature?",
    "What is the term for an animal that eats both plants and meat?",
    "Who painted 'The Starry Night'?",
    "What gas do humans exhale that plants use for photosynthesis?",
    "What year did World War II end?",
    "Which continent has the most countries?",
    "Who wrote the novel 'Frankenstein'?",
    "What does DNA stand for?",
    "What is the main ingredient in traditional Japanese miso soup?"
};

static std::vector<std::string> k_answers = {
    "The tallest mountain in the world is Mount Everest.",
    "Marie Curie was the first person to win two Nobel Prizes.",
    "Paper was invented in China.",
    "The heart is the organ responsible for pumping blood.",
    "Saturn is known for its prominent ring system.",
    "Christopher Nolan directed the movie 'Inception'.",
    "The freezing point of water in Fahrenheit is 32°F.",
    "The bowhead whale is known to have the longest lifespan among mammals.",
    "Mandarin Chinese has the most native speakers in the world.",
    "The capital city of Canada is Ottawa.",
    "Tim Berners-Lee is credited with inventing the World Wide Web.",
    "Mercury is the metal that is liquid at room temperature.",
    "An animal that eats both plants and meat is called an omnivore.",
    "'The Starry Night' was painted by Vincent van Gogh.",
    "Humans exhale carbon dioxide, which plants use in photosynthesis.",
    "World War II ended in 1945.",
    "Africa is the continent with the most countries.",
    "The novel 'Frankenstein' was written by Mary Shelley.",
    "DNA stands for Deoxyribonucleic Acid.",
    "The main ingredient in traditional Japanese miso soup is fermented soybean paste."
};

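// default client requests; overridden by the contents of the prompt file when one is provided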
static std::vector<std::string> k_prompts = {
    "What is the meaning of life?",
    "Tell me an interesting fact about llamas.",
    "What is the best way to cook a steak?",
    "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
    "Recommend some interesting books to read.",
    "What is the best way to learn a new language?",
    "How to get a job at Google?",
    "If you could have any superpower, what would it be?",
    "I want to learn how to play the piano. What would be the best way to do it?",
};

struct client {
    ~client() {
        if (smpl) {
            common_sampler_free(smpl);
        }
    }

    int32_t id = 0;

    // id of the sequence currently assigned to this client, or -1 if the slot is free
    llama_seq_id seq_id = -1;

    // last sampled token, fed back as input on the next decode
    llama_token sampled;

    int64_t t_start_prompt;
    int64_t t_start_gen;

    int32_t n_past    = 0;
    int32_t n_prompt  = 0;
    int32_t n_decoded = 0;

    // index into the current batch of the token whose logits this client samples from, or -1 if not scheduled
    int32_t i_batch   = -1;

    std::string input;
    std::string prompt;
    std::string response;

    struct common_sampler * smpl = nullptr;
};

static void print_date_time() {
    std::time_t current_time = std::time(nullptr);
    std::tm * local_time = std::localtime(&current_time);
    char buffer[80];
    strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", local_time);

    LOG_INF("\n");
    LOG_INF("\033[35mrun parameters as of %s\033[0m\n", buffer);
    LOG_INF("\n");
}

// split a string into parts separated by the given delimiter
static std::vector<std::string> split_string(const std::string & input, char delimiter) {
    std::vector<std::string> tokens;
    std::istringstream stream(input);
    std::string token;
    while (std::getline(stream, token, delimiter)) {
        tokens.push_back(token);
    }
    return tokens;
}

int main(int argc, char ** argv) {
    srand(1234);

    common_params params;

    params.n_predict = 128;
    params.n_junk    = 1;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
        return 1;
    }

    common_init();

    // number of simultaneous "clients" to simulate
    const int32_t n_clients = params.n_parallel;

    // dedicate one sequence to the system prompt
    params.n_parallel += 1;
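    // sequence 0 holds the system prompt; client i generates on sequence i + 1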

    // requests to simulate
    const int32_t n_seq = params.n_sequences;

    // continuous batching: insert a new request as soon as a client slot becomes free
    const bool cont_batching = params.cont_batching;

    // whether the system prompt is shared across all client sequences in the KV cache
    const bool is_sp_shared = params.is_pp_shared;

    // extra text to insert in each client's prompt in order to make it larger
    const int32_t n_junk = std::max(1, params.n_junk);

    // sampling seed; a negative value assigns a different seed to each client
    const int32_t & sseed = params.sampling.seed;

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the target model
    common_init_result llama_init = common_init_from_params(params);

    llama_model * model = llama_init.model.get();
    llama_context * ctx = llama_init.context.get();

    auto * mem = llama_get_memory(ctx);

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // load the prompts from an external file if there are any
    if (params.prompt.empty()) {
        LOG_INF("\033[32mNo new questions, so proceeding with built-in defaults.\033[0m\n");
    } else {
        // split the external prompt on newlines and copy each line into k_prompts
        int index = 0;
        LOG_INF("\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str());

        std::vector<std::string> prompts = split_string(params.prompt, '\n');
        for (const auto & prompt : prompts) {
            k_prompts.resize(index + 1);
            k_prompts[index] = prompt;
            index++;
            LOG_INF("%3d prompt: %s\n", index, prompt.c_str());
        }
    }

    LOG_INF("\n\n");

    const int n_ctx = llama_n_ctx(ctx);

    if (sseed >= 0) {
        LOG_INF("%s: initializing all samplers with the same RNG seed: %d (use a negative seed to have different seeds)\n", __func__, sseed);
    } else {
        LOG_INF("%s: initializing samplers with different RNG seeds, starting from %d\n", __func__, sseed);
    }

    std::vector<client> clients(n_clients);
    for (size_t i = 0; i < clients.size(); ++i) {
        auto & client = clients[i];
        client.id = i;
        client.smpl = common_sampler_init(model, params.sampling);

        // when a negative seed is requested, give each client its own seed by decrementing it after each init
        if (sseed < 0) {
            params.sampling.seed--;
        }
    }

    std::vector<llama_token> tokens_system;

    tokens_system = common_tokenize(ctx, k_system, true);
    const int32_t n_tokens_system = tokens_system.size();

    llama_seq_id g_seq_id = 0;

    // the max batch size is as large as the context to handle cases where we get very long input prompts from multiple
    // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
    llama_batch batch = llama_batch_init(n_ctx, 0, 1);

    int32_t n_total_prompt = 0;
    int32_t n_total_gen    = 0;
    int32_t n_cache_miss   = 0;

    const auto t_main_start = ggml_time_us();

    LOG_INF("%s: Simulating parallel requests from clients:\n", __func__);
    LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
    LOG_INF("\n");

    if (is_sp_shared) {
        LOG_INF("%s: Evaluating the system prompt ...\n", __func__);

        for (int32_t i = 0; i < n_tokens_system; ++i) {
            common_batch_add(batch, tokens_system[i], i, { 0 }, false);
        }

        if (llama_decode(ctx, batch) != 0) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }

        // assign the system KV cache to all parallel sequences
        for (int32_t i = 1; i <= n_clients; ++i) {
            llama_memory_seq_cp(mem, 0, i, -1, -1);
        }

        LOG_INF("\n");
    }
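
    // when the system prompt is shared, every client sequence now already contains the system tokens
    // in its KV cache, so each new request only has to evaluate the junk text and the user prompt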

    LOG_INF("Processing requests ...\n\n");

    while (true) {
        common_batch_clear(batch);

        // decode any currently ongoing sequences
        for (auto & client : clients) {
            if (client.seq_id == -1) {
                continue;
            }

            client.i_batch = batch.n_tokens;

            common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true);

            client.n_decoded += 1;
        }

        if (batch.n_tokens == 0) {
            // all sequences have ended - clear the entire KV cache
            for (int i = 1; i <= n_clients; ++i) {
                llama_memory_seq_rm(mem, i, -1, -1);
                // but keep the system prompt
                llama_memory_seq_cp(mem, 0, i, -1, -1);
            }

            LOG_INF("%s: clearing the KV cache\n", __func__);
        }

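        // with continuous batching, free client slots are refilled while other sequences are still
        // generating; otherwise, new requests are only inserted once the whole batch has drained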
        // insert new sequences for decoding
        if (cont_batching || batch.n_tokens == 0) {
            for (auto & client : clients) {
                if (client.seq_id == -1 && g_seq_id < n_seq) {
                    client.seq_id = g_seq_id;

                    client.t_start_prompt = ggml_time_us();
                    client.t_start_gen    = 0;

                    client.input    = k_prompts[rand() % k_prompts.size()];
                    client.response = "";

                    // construct the prompt:
                    // [system prompt] + [junk] + [user prompt]
                    client.n_past = 0;
                    client.prompt = "";
                    if (is_sp_shared) {
                        client.n_past = n_tokens_system;
                    } else {
                        client.prompt += k_system;
                    }

                    const int n_junk_cur = rand() % n_junk;

                    for (int i = 0; i < n_junk_cur; ++i) {
                        const int r = rand() % k_questions.size();
                        client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n";
                    }
                    client.prompt += "User:\n" + client.input + "\nAssistant:\n";

                    common_sampler_reset(client.smpl);

                    // do not prepend BOS because we have a system prompt!
                    std::vector<llama_token> tokens_prompt;
                    tokens_prompt = common_tokenize(ctx, client.prompt, false);

                    for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                        common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false);
                    }

                    // extract the logits only for the last token
                    if (batch.n_tokens > 0) {
                        batch.logits[batch.n_tokens - 1] = true;
                    }

                    client.n_prompt  = tokens_prompt.size();
                    client.n_decoded = 0;
                    client.i_batch   = batch.n_tokens - 1;

                    LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);

                    g_seq_id += 1;

                    // insert new requests one-by-one
                    //if (cont_batching) {
                    //    break;
                    //}
                }
            }
        }

        if (batch.n_tokens == 0) {
            break;
        }

        // process in chunks of params.n_batch
        int32_t n_batch = params.n_batch;

        int32_t i_next = 0;

        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
            // experiment: process in powers of 2
            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
            //    n_batch /= 2;
            //    i -= n_batch;
            //    continue;
            //}

            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

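            // the view below shares the token/pos/seq_id/logits buffers of the big batch, offset by i,
            // so each chunk is decoded without copying any data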
            llama_batch batch_view = {
                n_tokens,
                batch.token    + i,
                nullptr,
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.logits   + i,
            };

            const int ret = llama_decode(ctx, batch_view);
            if (ret != 0) {
                if (n_batch == 1 || ret < 0) {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
                    LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
                    return 1;
                }

                LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);

                n_cache_miss += 1;

                // retry with half the batch size to try to find a free slot in the KV cache
                n_batch /= 2;

                continue;
            }

            LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);

            // move the head of the batch forward with the number of tokens we just processed
            i_next = i + n_tokens;

            // on successful decode, restore the original batch size
            n_batch = params.n_batch;

            for (auto & client : clients) {
                if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
                    continue;
                }

                //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
                //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

                const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch - i);

                common_sampler_accept(client.smpl, id, true);

                if (client.n_decoded == 1) {
                    // start measuring generation time after the first token to make sure all concurrent clients
                    // have their prompt already processed
                    client.t_start_gen = ggml_time_us();
                }

                const std::string token_str = common_token_to_piece(ctx, id);

                client.response += token_str;
                client.sampled = id;

                //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
                //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());

                if (client.n_decoded > 2 &&
                        (llama_vocab_is_eog(vocab, id) ||
                         (params.n_predict > 0 && client.n_decoded >= params.n_predict) ||
                         client.response.find("User:") != std::string::npos)) {
                    // basic reverse prompt
                    const size_t pos = client.response.find("User:");
                    if (pos != std::string::npos) {
                        client.response = client.response.substr(0, pos);
                    }

                    // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
                    llama_memory_seq_rm(mem, client.id + 1, -1, -1);
                    llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);

                    const auto t_main_end = ggml_time_us();

                    LOG_INF("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput: %s\n\033[35mResponse: %s\033[0m\n\n",
                            client.id, client.seq_id, n_seq, client.n_prompt, client.n_decoded,
                            (t_main_end - client.t_start_prompt) / 1e6,
                            (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6,
                            n_cache_miss,
                            ::trim(client.input).c_str(),
                            ::trim(client.response).c_str());

                    n_total_prompt += client.n_prompt;
                    n_total_gen    += client.n_decoded;

                    client.seq_id = -1;
                }

                client.i_batch = -1;
            }
        }
    }

    const auto t_main_end = ggml_time_us();

    print_date_time();

    LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
    if (params.prompt_file.empty()) {
        params.prompt_file = "used built-in defaults";
    }
    LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
    LOG_INF("Model and path used:  \033[32m%s\033[0m\n\n", params.model.path.c_str());

    LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt              ) / (t_main_end - t_main_start) * 1e6);
    LOG_INF("Total gen tokens:    %6d, speed: %5.2f t/s\n", n_total_gen,    (double) (n_total_gen                 ) / (t_main_end - t_main_start) * 1e6);
    LOG_INF("Total speed (AVG):   %6s  speed: %5.2f t/s\n", "",             (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
    LOG_INF("Cache misses:        %6d\n", n_cache_miss);

    LOG_INF("\n");

    // TODO: print sampling/grammar timings for all clients
    llama_perf_context_print(ctx);

    llama_batch_free(batch);

    llama_backend_free();

    LOG("\n\n");

    return 0;
}
