// thread safety test
// - Loads a copy of the same model on each GPU, plus a copy on the CPU
// - Creates n_parallel (--parallel) contexts per model
// - Runs inference in parallel on each context

#include <array>
#include <thread>
#include <vector>
#include <atomic>
#include "llama.h"
#include "arg.h"
#include "common.h"
#include "log.h"
#include "sampling.h"

int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    common_init();

    llama_backend_init();
    llama_numa_init(params.numa);

    LOG_INF("%s\n", common_params_get_system_info(params).c_str());

    //llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
    //    if (level == GGML_LOG_LEVEL_ERROR) {
    //        common_log_add(common_log_main(), level, "%s", text);
    //    }
    //}, NULL);

    auto cparams = common_context_params_to_llama(params);

    // each context has a single sequence
    cparams.n_seq_max = 1;

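    // enumerate the available GPU devices; each {device, nullptr} entry acts as a
    // single-device, null-terminated list that can be assigned to mparams.devices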
    int dev_count = ggml_backend_dev_count();
    std::vector<std::array<ggml_backend_dev_t, 2>> gpus;
    for (int i = 0; i < dev_count; ++i) {
        auto * dev = ggml_backend_dev_get(i);
        if (dev && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
            gpus.push_back({dev, nullptr});
        }
    }
    const int gpu_dev_count = (int)gpus.size();
    const int num_models = gpu_dev_count + 1 + 1; // GPUs + 1 CPU model + 1 layer split
    //const int num_models = std::max(1, gpu_dev_count);
    const int num_contexts = std::max(1, params.n_parallel);

    std::vector<llama_model_ptr> models;
    std::vector<std::thread> threads;
    std::atomic<bool> failed = false;

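    // load one model pinned to each GPU, one model on the CPU,
    // and one model split by layer across all devices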
    for (int m = 0; m < num_models; ++m) {
        auto mparams = common_model_params_to_llama(params);

        if (m < gpu_dev_count) {
            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
            mparams.devices = gpus[m].data();
        } else if (m == gpu_dev_count) {
            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
            mparams.main_gpu = -1; // CPU model
        } else {
            mparams.split_mode = LLAMA_SPLIT_MODE_LAYER;
        }

        llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
        if (model == NULL) {
            LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
            return 1;
        }

        models.emplace_back(model);
    }

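    // spawn one thread per (model, context) pair; each thread creates its own
    // context and sampler and generates independently of the others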
    for (int m = 0; m < num_models; ++m) {
        auto * model = models[m].get();
        for (int c = 0; c < num_contexts; ++c) {
            threads.emplace_back([&, m, c, model]() {
                LOG_INF("Creating context %d/%d for model %d/%d\n", c + 1, num_contexts, m + 1, num_models);

                llama_context_ptr ctx { llama_init_from_model(model, cparams) };
                if (ctx == NULL) {
                    LOG_ERR("failed to create context\n");
                    failed.store(true);
                    return;
                }

                std::unique_ptr<common_sampler, decltype(&common_sampler_free)> sampler { common_sampler_init(model, params.sampling), common_sampler_free };
                if (sampler == NULL) {
                    LOG_ERR("failed to create sampler\n");
                    failed.store(true);
                    return;
                }

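                // tokenize the prompt and decode it as a single batch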
                llama_batch batch = {};
                {
                    auto prompt = common_tokenize(ctx.get(), params.prompt, true);
                    if (prompt.empty()) {
                        LOG_ERR("failed to tokenize prompt\n");
                        failed.store(true);
                        return;
                    }
                    batch = llama_batch_get_one(prompt.data(), prompt.size());
                    if (llama_decode(ctx.get(), batch)) {
                        LOG_ERR("failed to decode prompt\n");
                        failed.store(true);
                        return;
                    }
                }

                const auto * vocab = llama_model_get_vocab(model);
                std::string result = params.prompt;

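                // generation loop: sample the next token from the last decoded position,
                // append its text piece, and stop on end-of-generation or a full context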
                for (int i = 0; i < params.n_predict; i++) {
                    llama_token token;
                    if (batch.n_tokens > 0) {
                        token = common_sampler_sample(sampler.get(), ctx.get(), batch.n_tokens - 1);
                    } else {
                        token = llama_vocab_bos(vocab);
                    }

                    result += common_token_to_piece(ctx.get(), token);

                    if (llama_vocab_is_eog(vocab, token)) {
                        break;
                    }

                    batch = llama_batch_get_one(&token, 1);

                    int ret = llama_decode(ctx.get(), batch);
                    if (ret == 1 && i > 0) {
                        LOG_INF("Context full, stopping generation.\n");
                        break;
                    }

                    if (ret != 0) {
                        LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
                        failed.store(true);
                        return;
                    }
                }

                LOG_INF("Model %d/%d, Context %d/%d: %s\n\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
            });
        }
    }

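    // wait for every worker thread to finish before checking the outcome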
    for (auto & thread : threads) {
        thread.join();
    }

    if (failed) {
        LOG_ERR("One or more threads failed.\n");
        return 1;
    }

    LOG_INF("All threads finished without errors.\n");
    return 0;
}