#include "llama.h"
#include "common.h"

#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include <cctype>
#include <filesystem>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n    %s -m model.gguf [-ngl n_gpu_layers] [-embd-mode] [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
    printf("\n");
    printf("    -embd-norm: normalization type for pooled embeddings (default: 2)\n");
    printf("                -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
    printf("\n");
}

int main(int argc, char ** argv) {
    std::string model_path;
    std::string prompt = "Hello, my name is";
    int ngl = 0;
    bool embedding_mode = false;
    bool pooling_enabled = false;
    int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)

    {
        int i = 1;
        for (; i < argc; i++) {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    try {
                        ngl = std::stoi(argv[++i]);
                    } catch (...) {
                        print_usage(argc, argv);
                        return 1;
                    }
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-embd-mode") == 0) {
                embedding_mode = true;
            } else if (strcmp(argv[i], "-pooling") == 0) {
                pooling_enabled = true;
            } else if (strcmp(argv[i], "-embd-norm") == 0) {
                if (i + 1 < argc) {
                    try {
                        embd_norm = std::stoi(argv[++i]);
                    } catch (...) {
                        print_usage(argc, argv);
                        return 1;
                    }
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                // first non-flag argument: the prompt starts here
                break;
            }
        }

        if (model_path.empty()) {
            print_usage(argc, argv);
            return 1;
        }

        // join all remaining arguments into the prompt
        if (i < argc) {
            prompt = argv[i++];
            for (; i < argc; i++) {
                prompt += " ";
                prompt += argv[i];
            }
        }
    }

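    // example invocations (binary name and model file are illustrative,
    // not defined by this program):
    //
    //   ./logits -m model.gguf -ngl 99 "Hello, my name is"
    //   ./logits -m model.gguf -embd-mode -pooling -embd-norm 2 "Hello, my name is"
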
    ggml_backend_load_all();
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);

    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    // Extract basename from model_path, dropping the .gguf suffix
    const char * basename = strrchr(model_path.c_str(), '/');
    basename = (basename == NULL) ? model_path.c_str() : basename + 1;

    char model_name[256];
    strncpy(model_name, basename, 255);
    model_name[255] = '\0';

    char * dot = strrchr(model_name, '.');
    if (dot != NULL && strcmp(dot, ".gguf") == 0) {
        *dot = '\0';
    }
    printf("Model name: %s\n", model_name);

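    // note: llama_tokenize called with a NULL/zero-sized output buffer returns
    // the negative of the number of tokens required, which is used to size the
    // vector for the real tokenization pass below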
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);

    std::vector<llama_token> prompt_tokens(n_prompt);
    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
        return 1;
    }

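    // the context and batch are sized to exactly the prompt length since only a
    // single decode is performed; in embedding mode, n_ubatch is raised to match
    // n_batch so non-causal (encoder-style) models can process the whole batch
    // in one micro-batch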
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_prompt;
    ctx_params.n_batch = n_prompt;
    ctx_params.no_perf = false;
    if (embedding_mode) {
        ctx_params.embeddings = true;
        ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
        ctx_params.n_ubatch = ctx_params.n_batch;
    }

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (ctx == NULL) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    printf("Input prompt: \"%s\"\n", prompt.c_str());
    printf("Tokenized prompt (%d tokens): ", n_prompt);
    for (auto id : prompt_tokens) {
        char buf[128];
        int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
        if (n < 0) {
            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
            return 1;
        }
        std::string s(buf, n);
        printf("%s", s.c_str());
    }
    printf("\n");

    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());

    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return 1;
    }

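    // after the decode, either the embeddings or the final logits are read out;
    // data_ptr/data_size describe whichever buffer gets written to disk below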
    float * data_ptr;
    int data_size;
    const char * type;
    std::vector<float> embd_out;

    if (embedding_mode) {
        const int n_embd = llama_model_n_embd(model);
        const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
        const int n_embeddings = n_embd * n_embd_count;
        float * embeddings;
        type = "-embeddings";

        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
            embeddings = llama_get_embeddings_seq(ctx, 0);
            if (embeddings == NULL) {
                fprintf(stderr, "%s: error: failed to get sequence embeddings\n", __func__);
                return 1;
            }
            embd_out.resize(n_embeddings);
            printf("Normalizing embeddings using norm: %d\n", embd_norm);
            common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
            embeddings = embd_out.data();
        } else {
            embeddings = llama_get_embeddings(ctx);
            if (embeddings == NULL) {
                fprintf(stderr, "%s: error: failed to get embeddings\n", __func__);
                return 1;
            }
        }

        printf("Embedding dimension: %d\n", n_embd);
        printf("\n");

        // Print the first and last 3 values of each embedding row
        for (int j = 0; j < n_embd_count; j++) {
            printf("embedding %d: ", j);

            // Print first 3 values
            for (int i = 0; i < 3 && i < n_embd; i++) {
                printf("%9.6f ", embeddings[j * n_embd + i]);
            }

            printf(" ... ");

            // Print last 3 values
            for (int i = n_embd - 3; i < n_embd; i++) {
                if (i >= 0) {
                    printf("%9.6f ", embeddings[j * n_embd + i]);
                }
            }

            printf("\n");
        }
        printf("\n");

        printf("Embeddings size: %d\n", n_embeddings);

        data_ptr = embeddings;
        data_size = n_embeddings;
    } else {
        float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
        const int n_logits = llama_vocab_n_tokens(vocab);
        type = "";
        printf("Vocab size: %d\n", n_logits);

        data_ptr = logits;
        data_size = n_logits;
    }
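
    // output files: data/llamacpp-<model_name><type>.bin (raw float32 values)
    // and a matching .txt dump for quick inspection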

    std::filesystem::create_directory("data");

    // Save data to binary file
    char bin_filename[512];
    snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
    printf("Saving data to %s\n", bin_filename);

    FILE * f = fopen(bin_filename, "wb");
    if (f == NULL) {
        fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
        return 1;
    }
    fwrite(data_ptr, sizeof(float), data_size, f);
    fclose(f);

    // Also save as text for debugging
    char txt_filename[512];
    snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type);
    f = fopen(txt_filename, "w");
    if (f == NULL) {
        fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
        return 1;
    }
    for (int i = 0; i < data_size; i++) {
        fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
    }
    fclose(f);

    if (!embedding_mode) {
        printf("First 10 logits: ");
        for (int i = 0; i < 10 && i < data_size; i++) {
            printf("%.6f ", data_ptr[i]);
        }
        printf("\n");

        printf("Last 10 logits: ");
        for (int i = data_size - 10; i < data_size; i++) {
            if (i >= 0) printf("%.6f ", data_ptr[i]);
        }
        printf("\n\n");
    }

    printf("Data saved to %s\n", bin_filename);
    printf("Data saved to %s\n", txt_filename);

    llama_free(ctx);
    llama_model_free(model);

    return 0;
}