#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <fstream>
#include <iostream> // TODO: remove me

static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
    LOG("\n %s --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .\n", argv[0]);
    LOG("\n");
}

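// a chunk of text read from an input file, together with its tokenization and embedding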
struct chunk {
    // filename
    std::string filename;
    // original file position
    size_t filepos;
    // original text data
    std::string textdata;
    // tokenized text data
    std::vector<llama_token> tokens;
    // embedding
    std::vector<float> embedding;
};

// split the file into chunks: a chunk is emitted once it grows past chunk_size characters,
// and chunks are only ever cut at occurrences of chunk_separator
static std::vector<chunk> chunk_file(const std::string & filename, int chunk_size, const std::string & chunk_separator) {
    std::vector<chunk> chunks;
    std::ifstream f(filename.c_str());

    if (!f.is_open()) {
        LOG_ERR("could not open file %s\n", filename.c_str());
        return chunks;
    }

    chunk current_chunk;
    char buffer[1024];
    int64_t filepos = 0;
    std::string current;
    // the "|| f.gcount() > 0" also picks up the final partial read,
    // which fails the stream check but can still contain data
    while (f.read(buffer, 1024) || f.gcount() > 0) {
        current += std::string(buffer, f.gcount());
        size_t pos;
        while ((pos = current.find(chunk_separator)) != std::string::npos) {
            current_chunk.textdata += current.substr(0, pos + chunk_separator.size());
            if ((int) current_chunk.textdata.size() > chunk_size) {
                // save chunk
                current_chunk.filepos = filepos;
                current_chunk.filename = filename;
                chunks.push_back(current_chunk);
                // update filepos
                filepos += (int) current_chunk.textdata.size();
                // reset current_chunk
                current_chunk = chunk();
            }
            current = current.substr(pos + chunk_separator.size());
        }
    }
    // text after the last separator belongs to the trailing chunk
    current_chunk.textdata += current;

    // add leftover data to last chunk
    if (current_chunk.textdata.size() > 0) {
        if (chunks.empty()) {
            current_chunk.filepos = filepos;
            current_chunk.filename = filename;
            chunks.push_back(current_chunk);
        } else {
            chunks.back().textdata += current_chunk.textdata;
        }
    }
    f.close();
    return chunks;
}

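// add all tokens of one input sequence to the batch under the given seq_id;
// output is requested for every token so that embeddings can be read back after decoding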
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
    size_t n_tokens = tokens.size();
    for (size_t i = 0; i < n_tokens; i++) {
        common_batch_add(batch, tokens[i], i, { seq_id }, true);
    }
}

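// run the model on one batch and write the L2-normalized embedding of each
// sequence into `output` (n_seq rows of n_embd floats)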
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
    // clear previous kv_cache values (irrelevant for embeddings)
    llama_memory_clear(llama_get_memory(ctx), false);

    // run model
    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
    if (llama_decode(ctx, batch) < 0) {
        LOG_ERR("%s : failed to process\n", __func__);
    }

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.logits[i]) {
            continue;
        }

        // try to get sequence embeddings - supported only when pooling_type is not NONE
        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == NULL) {
            embd = llama_get_embeddings_ith(ctx, i);
            if (embd == NULL) {
                LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i);
                continue;
            }
        }

        float * out = output + batch.seq_id[i][0] * n_embd;
        common_embd_normalize(embd, out, n_embd, 2);
    }
}

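// overall flow: split the context files into chunks, embed every chunk once, then
// interactively read queries from stdin and print the top-k most similar chunks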
int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) {
        return 1;
    }

    common_init();

    // For BERT models, batch size must be equal to ubatch size
    params.n_ubatch = params.n_batch;
    params.embedding = true;

    if (params.chunk_size <= 0) {
        LOG_ERR("chunk_size must be positive\n");
        return 1;
    }
    if (params.context_files.empty()) {
        LOG_ERR("context_files must be specified\n");
        return 1;
    }

| 136 | LOG_INF("processing files:\n" ); |
| 137 | for (auto & context_file : params.context_files) { |
| 138 | LOG_INF("%s\n" , context_file.c_str()); |
| 139 | } |
| 140 | |
| 141 | std::vector<chunk> chunks; |
| 142 | for (auto & context_file : params.context_files) { |
| 143 | std::vector<chunk> file_chunk = chunk_file(filename: context_file, chunk_size: params.chunk_size, chunk_separator: params.chunk_separator); |
| 144 | chunks.insert(position: chunks.end(), first: file_chunk.begin(), last: file_chunk.end()); |
| 145 | } |
| 146 | LOG_INF("Number of chunks: %zu\n" , chunks.size()); |
| 147 | |
| 148 | llama_backend_init(); |
| 149 | llama_numa_init(numa: params.numa); |
| 150 | |
| 151 | // load the model |
| 152 | common_init_result llama_init = common_init_from_params(params); |
| 153 | |
| 154 | llama_model * model = llama_init.model.get(); |
| 155 | llama_context * ctx = llama_init.context.get(); |
| 156 | |
| 157 | if (model == NULL) { |
| 158 | LOG_ERR("%s: unable to load model\n" , __func__); |
| 159 | return 1; |
| 160 | } |

    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int n_ctx_train = llama_model_n_ctx_train(model);
    const int n_ctx = llama_n_ctx(ctx);

    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
        LOG_ERR("%s: pooling type NONE not supported\n", __func__);
        return 1;
    }

    if (n_ctx > n_ctx_train) {
        LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, n_ctx);
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

    // max batch size
    const uint64_t n_batch = params.n_batch;
    GGML_ASSERT(params.n_batch >= params.n_ctx);

    // tokenize the prompts and trim
    for (auto & chunk : chunks) {
        auto inp = common_tokenize(ctx, chunk.textdata, true, false);
        if (inp.size() > n_batch) {
            LOG_ERR("%s: chunk size (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                    __func__, (long long int) inp.size(), (long long int) n_batch);
            return 1;
        }
        // add eos if not present
        if (llama_vocab_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_vocab_eos(vocab))) {
            inp.push_back(llama_vocab_eos(vocab));
        }
        chunk.tokens = inp;
    }

    // tokenization stats
    if (params.verbose_prompt) {
        for (int i = 0; i < (int) chunks.size(); i++) {
            LOG_INF("%s: prompt %d: '%s'\n", __func__, i, chunks[i].textdata.c_str());
            LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, chunks[i].tokens.size());
            for (int j = 0; j < (int) chunks[i].tokens.size(); j++) {
                LOG_INF("%6d -> '%s'\n", chunks[i].tokens[j], common_token_to_piece(ctx, chunks[i].tokens[j]).c_str());
            }
            LOG_INF("\n\n");
        }
    }

    // initialize batch
    const int n_chunks = chunks.size();
    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

    // allocate output
    const int n_embd = llama_model_n_embd(model);
    std::vector<float> embeddings(n_chunks * n_embd, 0);
    float * emb = embeddings.data();

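    // each chunk becomes its own sequence in a shared llama_batch; when the next chunk
    // would not fit within n_batch tokens, the pending batch is processed first and its
    // embeddings are written to the output at the offset of the first prompt it contains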
    // break into batches
    int p = 0; // number of prompts processed already
    int s = 0; // number of prompts in current batch
    for (int k = 0; k < n_chunks; k++) {
        // tokens of this chunk (already checked to fit within n_batch above)
        auto & inp = chunks[k].tokens;

        const uint64_t n_toks = inp.size();

        // encode if at capacity
        if (batch.n_tokens + n_toks > n_batch) {
            float * out = emb + p * n_embd;
            batch_process(ctx, batch, out, s, n_embd);
            common_batch_clear(batch);
            p += s;
            s = 0;
        }

        // add to batch
        batch_add_seq(batch, inp, s);
        s += 1;
    }

    // final batch
    float * out = emb + p * n_embd;
    batch_process(ctx, batch, out, s, n_embd);

    // save embeddings to chunks
    for (int i = 0; i < n_chunks; i++) {
        chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
        // clear tokens as they are no longer needed
        chunks[i].tokens.clear();
    }

    struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);

    // start loop, receive query and return top k similar chunks based on cosine similarity
    std::string query;
    while (true) {
        LOG("Enter query: ");
        if (!std::getline(std::cin, query)) {
            break; // stop on end-of-input (e.g. Ctrl-D) so the cleanup below is reached
        }
        std::vector<int32_t> query_tokens = common_tokenize(ctx, query, true);

        batch_add_seq(query_batch, query_tokens, 0);

        std::vector<float> query_emb(n_embd, 0);
        batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);

        common_batch_clear(query_batch);

        // compute cosine similarities
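        // (both the chunk embeddings and the query embedding are L2-normalized by
        // batch_process, so the cosine similarity reduces to a dot product)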
        {
            std::vector<std::pair<int, float>> similarities;
            for (int i = 0; i < n_chunks; i++) {
                float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
                similarities.push_back(std::make_pair(i, sim));
            }

            // sort similarities in descending order
            std::sort(similarities.begin(), similarities.end(), [](const std::pair<int, float> & a, const std::pair<int, float> & b) {
                return a.second > b.second;
            });

            LOG("Top %d similar chunks:\n", params.sampling.top_k);
            for (int i = 0; i < std::min(params.sampling.top_k, (int) chunks.size()); i++) {
                LOG("filename: %s\n", chunks[similarities[i].first].filename.c_str());
                LOG("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos);
                LOG("similarity: %f\n", similarities[i].second);
                LOG("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str());
                LOG("--------------------\n");
            }
        }
    }

    LOG("\n");
    llama_perf_context_print(ctx);

    // clean up
    llama_batch_free(batch);
    llama_batch_free(query_batch);
    llama_backend_free();
}