1#include "arg.h"
2#include "common.h"
3#include "log.h"
4#include "llama.h"
5
6#include <chrono>
7#include <algorithm>
8#include <array>
9#include <atomic>
10#include <cmath>
11#include <cstdio>
12#include <cstring>
13#include <ctime>
14#include <fstream>
15#include <mutex>
16#include <random>
17#include <sstream>
18#include <thread>
19#include <vector>
20
21#if defined(_MSC_VER)
22#pragma warning(disable: 4244 4267) // possible loss of data
23#endif
24
25struct results_perplexity {
26 std::vector<llama_token> tokens;
27 double ppl_value;
28 std::vector<float> logits;
29 std::vector<float> probs;
30};
31
32struct results_log_softmax {
33 double log_softmax;
34 float logit;
35 float prob;
36};
37
static std::vector<float> softmax(const std::vector<float>& logits) {
    std::vector<float> probs(logits.size());
    float max_logit = logits[0];
    for (float v : logits) {
        max_logit = std::max(max_logit, v);
    }
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        // Subtract the maximum logit value from the current logit value for numerical stability
        const float logit = logits[i] - max_logit;
        const float exp_logit = expf(logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }
    for (size_t i = 0; i < probs.size(); i++) {
        probs[i] /= sum_exp;
    }
    return probs;
}
57
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        max_logit = std::max(max_logit, logits[i]);
    }
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
}
69
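// nearest_int rounds by exploiting the IEEE-754 float layout: adding 12582912.f (1.5 * 2^23)
// pushes the rounded integer into the low bits of the mantissa, which is then recovered by
// masking the 23 mantissa bits and subtracting 2^22 (0x00400000). This only works while
// |fval| stays below roughly 2^22, which is what the commented-out assert hints at.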
static inline int nearest_int(float fval) {
    //assert(fval <= 4194303.f);
    float val = fval + 12582912.f;
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}
76
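// The following overload additionally quantizes the full log-probability vector into uint16_t so it
// can be streamed to a logits file and compared against later by the KL-divergence path below.
// Layout of the log_prob buffer, as used here: the first 4 uint16_t slots are reinterpreted as two
// floats holding the quantization scale and the minimum log-probability, followed by n_vocab
// quantized values; a stored value q decodes back to roughly scale*q + min_log_prob. Logits more
// than 16 below the maximum are clamped, so very unlikely tokens all land in the same bucket.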
static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
    float max_logit = logits[0];
    float min_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        max_logit = std::max(max_logit, logits[i]);
        min_logit = std::min(min_logit, logits[i]);
    }
    min_logit = std::max(min_logit, max_logit - 16);
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    const float log_sum_exp = log(sum_exp);
    const float min_log_prob = min_logit - max_logit - log_sum_exp;
    const float scale = (max_logit - min_logit)/65535.f;
    float * d = (float *)log_prob;
    d[0] = scale;
    d[1] = min_log_prob;
    log_prob += 4;
    if (scale) {
        const float inv_scale = 1/scale;
        for (int i = 0; i < n_vocab; ++i) {
            log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
        }
    } else {
        std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
    }
    return max_logit + log_sum_exp - logits[tok];
}
106
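// process_logits spreads the per-token work over the worker threads: a shared counter guarded by a
// mutex hands out token indices, each thread accumulates its own nll/nll2 locally and folds them
// into the shared sums once the counter runs out. Note that the logits row for position i is scored
// against tokens[i+1], i.e. the model's prediction of the next token.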
107static void process_logits(
108 int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
109 double & nll, double & nll2, float * logit_history, float * prob_history
110) {
111 std::mutex mutex;
112 int counter = 0;
113 auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
114 double local_nll = 0;
115 double local_nll2 = 0;
116 while (true) {
117 std::unique_lock<std::mutex> lock(mutex);
118 int i = counter++;
119 if (i >= n_token) {
120 nll += local_nll; nll2 += local_nll2;
121 break;
122 }
123 lock.unlock();
            const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
125 const double v = -results.log_softmax;
126 local_nll += v;
127 local_nll2 += v*v;
128
129 logit_history[i] = results.logit;
130 prob_history[i] = results.prob;
131 }
132 };
133 for (auto & w : workers) {
134 w = std::thread(compute);
135 }
136 compute();
137 for (auto & w : workers) {
138 w.join();
139 }
140}
141
142static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
143 std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
144 std::mutex mutex;
145 const int nv = 2*((n_vocab + 1)/2) + 4;
146 int counter = 0;
147 auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
148 double local_nll = 0;
149 double local_nll2 = 0;
150 while (true) {
151 std::unique_lock<std::mutex> lock(mutex);
152 int i = counter++;
153 if (i >= n_token) {
154 nll += local_nll; nll2 += local_nll2;
155 break;
156 }
157 lock.unlock();
            const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
159 local_nll += v;
160 local_nll2 += v*v;
161 }
162 };
163 for (auto & w : workers) {
164 w = std::thread(compute);
165 }
166 compute();
167 for (auto & w : workers) {
168 w.join();
169 }
    out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
171}
172
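// Running sums used when comparing against a base model's saved logits: nll/nll_base accumulate the
// negative log-likelihood of the current and the base model (plus their squares and cross term for
// error estimates), sum_kld accumulates per-token KL divergence, the p_diff fields track the
// difference in the probability assigned to the observed token, and n_same_top counts how often
// both models agree on the top token.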
173struct kl_divergence_result {
174 double sum_nll = 0.0;
175 double sum_nll2 = 0.0;
176 double sum_nll_base = 0.0;
177 double sum_nll_base2 = 0.0;
178 double sum_nll_nll_base = 0.0;
179 double sum_kld = 0.0;
180 double sum_kld2 = 0.0;
181 double sum_p_diff = 0.0;
182 double sum_p_diff2 = 0.0;
183 double sum_p_diff4 = 0.0;
184 float max_p_diff = 0.0f;
    size_t n_same_top = 0;
    size_t count = 0;
187};
188
189static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
190 float max_logit = logits[0];
191 int imax = 0;
192 for (int i = 1; i < n_vocab; ++i) {
193 if (logits[i] > max_logit) {
194 max_logit = logits[i];
195 imax = i;
196 }
197 }
198 double sum_exp = 0.0;
199 for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    const float log_sum_exp = log(sum_exp);
203 const float * d = (const float *)base_log_prob;
204 const float scale = d[0];
205 const float min_log_prob = d[1];
206 base_log_prob += 4;
207
208 const float nll = max_logit + log_sum_exp - logits[tok];
209 kld.sum_nll += nll;
210 kld.sum_nll2 += nll*nll;
211
212 const float nll_base = -(scale*base_log_prob[tok] + min_log_prob);
213 kld.sum_nll_base += nll_base;
214 kld.sum_nll_base2 += nll_base*nll_base;
215
216 kld.sum_nll_nll_base += nll*nll_base;
217
218 max_logit += log_sum_exp;
219 double sum = 0;
220 int imax_base = -1;
221 float p_log_base_max = 0;
222 for (int i = 0; i < n_vocab; ++i) {
223 const float p_log_base = scale*base_log_prob[i] + min_log_prob;
224 if (i == 0 || p_log_base > p_log_base_max) {
225 p_log_base_max = p_log_base;
226 imax_base = i;
227 }
228 if (p_log_base > -16.f) {
            const float p_base = expf(p_log_base);
230 sum += p_base * (p_log_base - logits[i] + max_logit);
231 }
232 }
233 kld.sum_kld += sum;
234 kld.sum_kld2 += sum*sum;
235 ++kld.count;
236 if (imax == imax_base) {
237 ++kld.n_same_top;
238 }
239
    const float p_base = expf(-nll_base);
    const float p = expf(-nll);
    const float p_diff = p - p_base;
    kld.sum_p_diff += p_diff;
    const double p_diff2 = p_diff*p_diff;
    kld.sum_p_diff2 += p_diff2;
    kld.sum_p_diff4 += p_diff2*p_diff2;
    kld.max_p_diff = std::max(kld.max_p_diff, std::fabs(p_diff));

    return std::make_pair(sum, p_diff);
250}
251
252static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
253 std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
254 float * kld_values, float * p_diff_values) {
255 std::mutex mutex;
256 const int nv = 2*((n_vocab + 1)/2) + 4;
257 int counter = 0;
258 auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values, p_diff_values] () {
259 kl_divergence_result local_kld;
260 while (true) {
261 std::unique_lock<std::mutex> lock(mutex);
262 int i = counter++;
263 if (i >= n_token) {
264 kld.sum_nll += local_kld.sum_nll;
265 kld.sum_nll2 += local_kld.sum_nll2;
266 kld.sum_nll_base += local_kld.sum_nll_base;
267 kld.sum_nll_base2 += local_kld.sum_nll_base2;
268 kld.sum_nll_nll_base += local_kld.sum_nll_nll_base;
269 kld.sum_kld += local_kld.sum_kld;
270 kld.sum_kld2 += local_kld.sum_kld2;
271 kld.sum_p_diff += local_kld.sum_p_diff;
272 kld.sum_p_diff2 += local_kld.sum_p_diff2;
273 kld.sum_p_diff4 += local_kld.sum_p_diff4;
274 kld.n_same_top += local_kld.n_same_top;
                kld.max_p_diff = std::max(kld.max_p_diff, local_kld.max_p_diff);
276 kld.count += local_kld.count;
277 break;
278 }
279 lock.unlock();
            std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
281 kld_values[i] = (float)v.first;
282 p_diff_values[i] = v.second;
283 }
284 };
285 for (auto & w : workers) {
286 w = std::thread(compute);
287 }
288 compute();
289 for (auto & w : workers) {
290 w.join();
291 }
292}
293
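// perplexity_v2 is the strided variant (used when params.ppl_stride > 0): every chunk re-evaluates a
// full window of n_ctx tokens but only the last ppl_stride positions of each window contribute to the
// NLL sum, so consecutive windows overlap. This gives a sliding-window style estimate at the cost of
// decoding most tokens more than once.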
294static results_perplexity perplexity_v2(llama_context * ctx, const common_params & params) {
295 // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
296 // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
297 // Output: `perplexity: 13.5106 [114/114]`
298 // BOS tokens will be added for each chunk before eval
299
300 const llama_model * model = llama_get_model(ctx);
301 const llama_vocab * vocab = llama_model_get_vocab(model);
302
303 const bool add_bos = llama_vocab_get_add_bos(vocab);
304 GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
305
306 LOG_INF("%s: tokenizing the input ..\n", __func__);
307
    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
309
310 const int n_ctx = llama_n_ctx(ctx);
311
312 if (int(tokens.size()) < 2*n_ctx) {
313 LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
314 n_ctx);
315 LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
        return {std::move(tokens), 0., {}, {}};
317 }
318
319 std::vector<float> logit_history;
320 std::vector<float> prob_history;
321
    logit_history.resize(tokens.size());
    prob_history.resize(tokens.size());

    if (params.ppl_stride <= 0) {
        LOG_ERR("%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
        return {tokens, -1, logit_history, prob_history};
328 }
329
330 const int calc_chunk = n_ctx;
331
332 LOG_INF("%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
333
334 if (int(tokens.size()) <= calc_chunk) {
335 LOG_ERR("%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
336 tokens.size(), n_ctx, params.ppl_stride);
        return {tokens, -1, logit_history, prob_history};
338 }
339
340 const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
341
    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
343 const int n_batch = params.n_batch;
344
345 const int n_vocab = llama_vocab_n_tokens(vocab);
346
347 int count = 0;
348 double nll = 0.0;
349
350 LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
351
352 for (int i = 0; i < n_chunk; ++i) {
353 const int start = i * params.ppl_stride;
354 const int end = start + calc_chunk;
355
356 const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
357 //LOG_DBG("%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
358
359 std::vector<float> logits;
360
361 const auto t_start = std::chrono::high_resolution_clock::now();
362
363 // clear the KV cache
        llama_memory_clear(llama_get_memory(ctx), true);

        llama_batch batch = llama_batch_init(n_batch, 0, 1);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            common_batch_clear(batch);
            for (int i = 0; i < batch_size; i++) {
                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
375 }
376
377 //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
378 if (llama_decode(ctx, batch)) {
379 //LOG_ERR("%s : failed to eval\n", __func__);
380 llama_batch_free(batch);
                return {tokens, -1, logit_history, prob_history};
382 }
383
384 // save original token and restore it after eval
385 const auto token_org = tokens[batch_start];
386
387 // add BOS token for the first batch of each chunk
388 if (add_bos && j == 0) {
389 tokens[batch_start] = llama_vocab_bos(vocab);
390 }
391
392 const auto * batch_logits = llama_get_logits(ctx);
            logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
394
395 if (j == 0) {
396 tokens[batch_start] = token_org;
397 }
398 }
399
400 llama_batch_free(batch);
401
402 const auto t_end = std::chrono::high_resolution_clock::now();
403
404 if (i == 0) {
405 const float t_total = std::chrono::duration<float>(t_end - t_start).count();
406 LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
407 int total_seconds = (int)(t_total * n_chunk);
408 if (total_seconds >= 60*60) {
409 LOG("%d hours ", total_seconds / (60*60));
410 total_seconds = total_seconds % (60*60);
411 }
412 LOG("%.2f minutes\n", total_seconds / 60.0);
413 }
414
415 //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
416 for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
417 // Calculate probability of next token, given the previous ones.
418 const std::vector<float> tok_logits(
419 logits.begin() + size_t(j + 0) * n_vocab,
420 logits.begin() + size_t(j + 1) * n_vocab);
421
            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
            logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
            prob_history[start + j + 1]  = prob;

            nll += -std::log(prob);
427 ++count;
428 }
429 // perplexity is e^(average negative log-likelihood)
430 if (params.ppl_output_type == 0) {
431 LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
432 } else {
433 LOG("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
434 }
435 }
436 LOG("\n");
437
    return {tokens, std::exp(nll / count), logit_history, prob_history};
439}
440
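// Standard perplexity over non-overlapping chunks of n_ctx tokens; only the second half of each
// window is scored so every scored token has at least n_ctx/2 tokens of context (see the longer
// comment further down). When params.logits_file is set, the quantized log-probabilities of every
// evaluated token are also streamed to disk, roughly as: the magic string "_logits_", n_ctx,
// n_vocab, n_chunk, the evaluated tokens, then one block of quantized log-probs per chunk.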
441static results_perplexity perplexity(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
442 if (params.ppl_stride > 0) {
443 return perplexity_v2(ctx, params);
444 }
445
446 // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
447 // Run `./llama-perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
448 // Output: `perplexity: 13.5106 [114/114]`
449 // BOS tokens will be added for each chunk before eval
450
451 const llama_model * model = llama_get_model(ctx);
452 const llama_vocab * vocab = llama_model_get_vocab(model);
453
454 const bool add_bos = llama_vocab_get_add_bos(vocab);
455 GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
456
457 std::ofstream logits_stream;
458 if (!params.logits_file.empty()) {
459 logits_stream.open(s: params.logits_file.c_str(), mode: std::ios::binary);
460 if (!logits_stream.is_open()) {
461 LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
462 return {};
463 }
464 LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
        logits_stream.write("_logits_", 8);
        logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
467 }
468
469 auto tim1 = std::chrono::high_resolution_clock::now();
470 LOG_INF("%s: tokenizing the input ..\n", __func__);
471
    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
473
474 auto tim2 = std::chrono::high_resolution_clock::now();
475 LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
476
477 if (int(tokens.size()) < 2*n_ctx) {
478 LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
479 n_ctx);
480 LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
        return {std::move(tokens), 0., {}, {}};
    }

    std::vector<float> logit_history;
    logit_history.resize(tokens.size());

    std::vector<float> prob_history;
    prob_history.resize(tokens.size());
489
490 const int n_chunk_max = tokens.size() / n_ctx;
491
492 const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(a: params.n_chunks, b: n_chunk_max);
493 const int n_batch = params.n_batch;
494
495 const int n_vocab = llama_vocab_n_tokens(vocab);
496
497 int count = 0;
498 double nll = 0.0;
499 double nll2 = 0.0;
500
501 const int num_batches = (n_ctx + n_batch - 1) / n_batch;
502 const int n_seq = std::max(a: 1, b: n_batch / n_ctx);
503
504 GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
505 GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
506
    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
508
509 std::vector<float> logits;
510 if (num_batches > 1) {
511 logits.reserve(n: size_t(n_ctx) * n_vocab);
512 }
513
514 LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
515
516 std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
517
518 std::vector<uint16_t> log_probs;
519 if (!params.logits_file.empty()) {
        logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
        logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
        logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
        const int nv = 2*((n_vocab + 1)/2) + 4;
        log_probs.resize(n_ctx * nv);
525 }
526
527 // We get the logits for all the tokens in the context window (params.n_ctx)
528 // from llama_decode below. Now, based on https://huggingface.co/docs/transformers/perplexity,
529 // calculate the perplexity over the last half of the window (so the model always has
530 // some context to predict the token).
531 //
532 // We rely on the fact that attention in the forward pass only looks at previous
533 // tokens here, so the logits returned for each token are an accurate representation
534 // of what the model would have predicted at that point.
535 //
536 // Example, we have a context window of 512, we will compute perplexity for each of the
537 // last 256 tokens. Then, we split the input up into context window size chunks to
538 // process the entire prompt.
539 const int first = n_ctx/2;
540
541 for (int i = 0; i < n_chunk; i += n_seq) {
542 const int start = i * n_ctx;
543 const int end = start + n_ctx;
544
545 const int n_seq_batch = std::min(a: n_seq, b: n_chunk - i);
546
547 const auto t_start = std::chrono::high_resolution_clock::now();
548
549 // clear the KV cache
        llama_memory_clear(llama_get_memory(ctx), true);
551
552 for (int j = 0; j < num_batches; ++j) {
553 const int batch_start = start + j * n_batch;
554 const int batch_size = std::min(a: end - batch_start, b: n_batch);
555
556 int n_outputs = 0;
557
558 batch.n_tokens = 0;
559 for (int seq = 0; seq < n_seq_batch; seq++) {
560 int seq_start = batch_start + seq*n_ctx;
561
562 // save original token and restore it after decode
563 const auto token_org = tokens[seq_start];
564
565 // add BOS token for the first batch of each chunk
566 if (add_bos && j == 0) {
567 tokens[seq_start] = llama_vocab_bos(vocab);
568 }
569
570 for (int k = 0; k < batch_size; ++k) {
571 const int idx = seq*n_ctx + k;
572 batch.token [idx] = tokens[seq_start + k];
573 batch.pos [idx] = j*n_batch + k;
574 batch.n_seq_id[idx] = 1;
575 batch.seq_id [idx][0] = seq;
576 batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
577
578 n_outputs += batch.logits[idx] != 0;
579 }
580 batch.n_tokens += batch_size;
581
582 // restore the original token in case it was set to BOS
583 tokens[seq_start] = token_org;
584 }
585
586 if (llama_decode(ctx, batch)) {
587 LOG_INF("%s : failed to decode\n", __func__);
                return {tokens, -1, logit_history, prob_history};
            }

            if (num_batches > 1 && n_outputs > 0) {
                const auto * batch_logits = llama_get_logits(ctx);
                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
594 }
595 }
596
597
598 if (i == 0) {
599 llama_synchronize(ctx);
600 const auto t_end = std::chrono::high_resolution_clock::now();
601 const float t_total = std::chrono::duration<float>(t_end - t_start).count();
602 LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
603 int total_seconds = (int)(t_total*n_chunk/n_seq);
604 if (total_seconds >= 60*60) {
605 LOG("%d hours ", total_seconds / (60*60));
606 total_seconds = total_seconds % (60*60);
607 }
608 LOG("%.2f minutes\n", total_seconds / 60.0);
609 }
610
611 for (int seq = 0; seq < n_seq_batch; seq++) {
            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);

            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
            if (!params.logits_file.empty()) {
                process_logits(logits_stream, n_vocab, all_logits,
                        tokens_data, n_ctx - 1 - first,
                        workers, log_probs, nll, nll2);
            } else {
                process_logits(n_vocab, all_logits,
                        tokens_data, n_ctx - 1 - first,
                        workers, nll, nll2,
                        logit_history.data() + start + seq*n_ctx + first,
                        prob_history.data() + start + seq*n_ctx + first);
625 }
626 count += n_ctx - first - 1;
627
628 // perplexity is e^(average negative log-likelihood)
629 if (params.ppl_output_type == 0) {
630 LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
631 } else {
632 double av = nll/count;
633 double av2 = nll2/count - av*av;
634 if (av2 > 0) {
                    av2 = sqrt(av2/(count-1));
636 }
637 LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
638 }
639 }
640
641 logits.clear();
642 }
643 LOG("\n");
644
645 nll2 /= count;
646 nll /= count;
    const double ppl = exp(nll);
    nll2 -= nll * nll;
    if (nll2 > 0) {
        nll2 = sqrt(nll2/(count-1));
651 LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
652 } else {
653 LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
654 }
655
656 llama_batch_free(batch);
657
    return {tokens, ppl, logit_history, prob_history};
659}
660
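// decode_helper splits a (possibly large) batch into views of at most n_batch tokens, decodes each
// view and copies the logits of the tokens that requested output into batch_logits, packed densely
// in the order the outputs were requested.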
661static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
662 int prev_outputs = 0;
663 for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
        const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);

        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        const int ret = llama_decode(ctx, batch_view);
677 if (ret != 0) {
678 LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
679 return false;
680 }
681
682 int n_outputs = 0;
683 for (int i = 0; i < n_tokens; ++i) {
684 n_outputs += batch_view.logits[i] != 0;
685 }
686
        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
688
689 prev_outputs += n_outputs;
690 }
691
692 return true;
693}
694
695#define K_TOKEN_CHUNK 4
696
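// compute_logprobs evaluates, for every (logits_row, token) pair in eval_pairs, the log-probability
// of that token under the softmax of the given logits row and writes the result into eval_results.
// Work is handed out to the workers in chunks of K_TOKEN_CHUNK via an atomic counter.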
697static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
698 const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
699 if (eval_results.size() != eval_pairs.size()) {
        eval_results.resize(eval_pairs.size());
701 }
702 if (eval_pairs.empty()) {
703 return;
704 }
705
    size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
707
708 std::atomic<int> counter(0);
709 auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
710 float local_logprobs[K_TOKEN_CHUNK];
711 while (true) {
            const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
            if (first >= eval_results.size()) {
                break;
            }
            const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
            for (size_t i = first; i < last; ++i) {
                const auto * logits = batch_logits + eval_pairs[i].first * n_vocab;
                float max_logit = logits[0];
                for (int j = 1; j < n_vocab; ++j) {
                    max_logit = std::max(max_logit, logits[j]);
                }
                float sum_p = 0.f;
                for (int j = 0; j < n_vocab; ++j) {
                    sum_p += expf(logits[j] - max_logit);
                }
                local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
            }
            std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
730 }
731 };
732
733 for (size_t it = 0; it < max_threads; ++it) {
734 workers[it] = std::thread(compute);
735 }
736 for (size_t it = 0; it < max_threads; ++it) {
737 workers[it].join();
738 }
739}
740
741static void hellaswag_score(llama_context * ctx, const common_params & params) {
742 const llama_model * model = llama_get_model(ctx);
743 const llama_vocab * vocab = llama_model_get_vocab(model);
744
745 // Calculates hellaswag score (acc_norm) from prompt
746 //
747 // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
748 // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
749 //
750 // All 10042 tasks should be extracted to keep the results standardized like other implementations.
751 //
752 // Datafile layout:
753 // ['??'] denotes json fields
754 // 6 lines per task:
755 // ['activity_label'] + ": " +['ctx'] - The first part of the query, the context
756 // ['label'] - The index the best common sense ending aka gold ending
757 // ['endings'][0] - Endings added to the first part of the query
758 // ['endings'][1]
759 // ['endings'][2]
760 // ['endings'][3]
761
762 std::vector<std::string> prompt_lines;
763 std::istringstream strstream(params.prompt);
764 std::string line;
765
    while (std::getline(strstream, line, '\n')) {
        prompt_lines.push_back(line);
768 }
769
770 if (prompt_lines.size() % 6 != 0) {
771 LOG_ERR("%s : number of lines in prompt not a multiple of 6.\n", __func__);
772 return;
773 }
774
775 size_t hs_task_count = prompt_lines.size()/6;
776 LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
777
778 const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM;
779 LOG_INF("================================= is_spm = %d\n", is_spm);
780
781 // The tasks should be randomized so the score stabilizes quickly.
782 bool randomize_tasks = true;
783
784 // Number of tasks to use when computing the score
785 if (params.hellaswag_tasks < hs_task_count) {
786 hs_task_count = params.hellaswag_tasks;
787 }
788
789 // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
790 std::mt19937 rng(1);
791
792 // Dataholder for hellaswag tasks
793 struct hs_data_t {
794 std::string context;
795 size_t gold_ending_idx;
796 std::string ending[4];
797 size_t ending_logprob_count[4];
798 double ending_logprob[4];
799
800 size_t i_logits; // starting index of logits in the llama_batch
801 size_t common_prefix; // max number of initial tokens that are the same in all sentences
802 size_t required_tokens; // needed number of tokens to evaluate all 4 endings
803 std::vector<llama_token> seq_tokens[4];
804 };
805
806 LOG_INF("%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
807
808 // Select and read data from prompt lines
809 std::vector<hs_data_t> hs_data(hs_task_count);
810 for (size_t i = 0; i < hs_task_count; i++) {
811 size_t idx = i;
812
813 auto & hs_cur = hs_data[i];
814
815 // Select a random example of those left in the prompt
816 if (randomize_tasks) {
817 std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
818 idx = dist(rng);
819 }
820
821 hs_cur.context = prompt_lines[idx*6];
        hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
        for (size_t j = 0; j < 4; j++) {
            hs_cur.ending[j] = prompt_lines[idx*6+2+j];
            hs_cur.seq_tokens[j] = common_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
826 }
827
828 // determine the common prefix of the endings
829 hs_cur.common_prefix = 0;
830 for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
831 if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
832 hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
833 hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
834 break;
835 }
836 hs_cur.common_prefix++;
837 }
838 hs_cur.required_tokens = hs_cur.common_prefix +
839 hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
840 hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
841 hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
842 hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
843
844 //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size());
845
846 // Delete the selected random example from the prompt
847 if (randomize_tasks) {
848 prompt_lines.erase( first: std::next(x: prompt_lines.begin(),n: idx*6) , last: std::next(x: prompt_lines.begin(),n: idx*6+6) );
849 }
850 }
851
852 LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);
853
854 LOG("\ntask\tacc_norm\t95%% confidence interval\n");
855
856 double acc = 0.0f;
857
858 const int n_ctx = llama_n_ctx(ctx);
859 const int n_batch = params.n_batch;
860
861 const int n_vocab = llama_vocab_n_tokens(vocab);
862
863 const int max_tasks_per_batch = 32;
    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

    llama_batch batch = llama_batch_init(n_ctx, 0, 4);
867
868 std::vector<float> tok_logits(n_vocab);
869 // TODO: this could be made smaller; it's currently the worst-case size
870 std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
871
872 std::vector<std::pair<size_t, llama_token>> eval_pairs;
873 std::vector<float> eval_results;
874 std::vector<std::thread> workers(std::thread::hardware_concurrency());
875
876 for (size_t i0 = 0; i0 < hs_task_count; i0++) {
877 int n_cur = 0;
878
879 size_t i1 = i0;
880 size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
881
882 common_batch_clear(batch);
883
        // batch as many tasks as possible into the available context
885 // each task has 4 unique sequence ids - one for each ending
886 // the common prefix is shared among the 4 sequences to save tokens
887 // we extract logits only from the last common token and from all ending tokens of each sequence
888 while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
889 auto & hs_cur = hs_data[i1];
890 int n_logits = 0;
891
892 const int s0 = 4*(i1 - i0);
893 if (s0 + 4 > max_seq) {
894 break;
895 }
896
897 for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
898 common_batch_add(batch, id: hs_cur.seq_tokens[0][i], pos: i, seq_ids: { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, logits: false);
899 }
900 batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
901 n_logits += 1;
902
903 for (int s = 0; s < 4; ++s) {
904 const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
905 // TODO: don't evaluate the last token of each sequence
906 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
907 const bool needs_logits = i < seq_tokens_size - 1;
908 common_batch_add(batch, id: hs_cur.seq_tokens[s][i], pos: i, seq_ids: { s0 + s }, logits: needs_logits);
909 n_logits += needs_logits;
910 }
911 }
912
913 hs_cur.i_logits = i_logits;
914 i_logits += n_logits;
915
916 n_cur += hs_data[i1].required_tokens;
917 if (++i1 == hs_task_count) {
918 break;
919 }
920 }
921
922 if (i0 == i1) {
923 LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, hs_data[i0].required_tokens);
924 return;
925 }
926
927 llama_memory_clear(mem: llama_get_memory(ctx), data: true);
928
929 // decode all tasks [i0, i1)
930 if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
931 LOG_ERR("%s: llama_decode() failed\n", __func__);
932 return;
933 }
934
935 // Compute log-probs in parallel
936 // First we collect all tasks
937 eval_pairs.clear();
938 for (size_t i = i0; i < i1; ++i) {
939 auto & hs_cur = hs_data[i];
940 size_t li = 1; // skip the last logit of the common prefix (computed separately below)
941 for (int s = 0; s < 4; ++s) {
942 for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
                    eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
                }
            }
        }
        // Then we do the actual calculation
        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
949
950 size_t ir = 0;
951
952 // compute the logprobs for each ending of the decoded tasks
953 for (size_t i = i0; i < i1; ++i) {
954 auto & hs_cur = hs_data[i];
955
956 // get the logits of the last token of the common prefix
            std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float));

            const auto first_probs = softmax(tok_logits);

            for (int s = 0; s < 4; ++s) {
                hs_cur.ending_logprob_count[s] = 1;
                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
964 for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
965 hs_cur.ending_logprob[s] += eval_results[ir++];
966 hs_cur.ending_logprob_count[s]++;
967 }
968 hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
969 }
970
971 // Find the ending with maximum logprob
972 size_t ending_logprob_max_idx = 0;
973 double ending_logprob_max_val = hs_cur.ending_logprob[0];
974 for (size_t s = 1; s < 4; s++) {
975 if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
976 ending_logprob_max_idx = s;
977 ending_logprob_max_val = hs_cur.ending_logprob[s];
978 }
979 }
980
981 //LOG("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
982
            // If the gold ending got the maximum logprob, add one accuracy point
984 if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
985 acc += 1.0;
986 }
987
988 double freq = acc / double(i + 1);
989
990 const double za = 1.95996398454;
991
992 // // Wald normal approx
993 // double conf =za*sqrt(freq*(1-freq)/double(i + 1));
994 // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
995
996 // Wilson score interval, more accurate
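            // With za the 95% normal quantile and n = i + 1 tasks, the code below evaluates
            //   (freq + za^2/(2n) +/- za*sqrt(freq*(1-freq)/n + za^2/(4n^2))) / (1 + za^2/n)
            // i.e. the standard Wilson interval for the observed accuracy freq (here z = za^2/n).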
997 double z = za * za / double(i + 1);
            double cnf = z * sqrt(double(i + 1) * (4.0 * freq * (1 - freq) + z)) / (za + za);
999 double a = (freq + z * 0.5 - cnf) / (1.0 + z);
1000 double b = (freq + z * 0.5 + cnf) / (1.0 + z);
1001
1002 // Print the accumulated accuracy mean x 100 and confidence interval
1003 LOG("%zu\t%3.8lf%%\t[%3.4lf%%, %3.4lf%%]\n", i + 1, freq * 100.0, a * 100.0, b * 100.0);
1004 }
1005
1006 i0 = i1 - 1;
1007 }
1008
1009 llama_batch_free(batch);
1010
1011 LOG("\n");
1012}
1013
1014struct winogrande_entry {
1015 std::string first;
1016 std::string second;
1017 std::array<std::string, 2> choices;
1018 int answer;
1019
1020 size_t i_logits;
1021 size_t common_prefix;
1022 size_t required_tokens;
1023 size_t n_base1; // number of tokens for context + choice 1
1024 size_t n_base2; // number of tokens for context + choice 2
1025 std::vector<llama_token> seq_tokens[2];
1026};
1027
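// Parses the Winogrande CSV prompt: each line is expected to be <index>,<sentence>,<choice1>,<choice2>,<answer>
// where the sentence contains a single '_' placeholder and may be quoted (commas inside quotes are ignored).
// The sentence is split at the placeholder into wg.first / wg.second and the answer must be 1 or 2.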
1028static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
1029 std::vector<winogrande_entry> result;
1030 std::istringstream in(prompt);
1031 std::string line;
1032 std::array<int, 4> comma_pos;
1033 while (true) {
1034 std::getline(is&: in, str&: line);
1035 if (in.fail() || in.eof()) break;
1036 int ipos = 0;
1037 bool quote_open = false;
1038 for (int i = 0; i < int(line.size()); ++i) {
1039 if (!quote_open) {
1040 if (line[i] == ',') {
1041 comma_pos[ipos++] = i;
1042 if (ipos == 4) break;
1043 }
1044 else if (line[i] == '"') {
1045 quote_open = true;
1046 }
1047 }
1048 else {
1049 if (line[i] == '"') {
1050 quote_open = false;
1051 }
1052 }
1053 }
1054 if (ipos != 4) {
1055 LOG_ERR("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
1056 continue;
1057 }
        auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
                                                    : line.substr(comma_pos[0]+1, comma_pos[1] - comma_pos[0] - 1);
        auto choice1 = line.substr(comma_pos[1]+1, comma_pos[2] - comma_pos[1] - 1);
        auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
        auto answer  = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
        auto index   = line.substr(0, comma_pos[0]);
1064 int where = 0;
1065 for ( ; where < int(sentence.size()); ++where) {
1066 if (sentence[where] == '_') break;
1067 }
1068 if (where == int(sentence.size())) {
1069 LOG_ERR("%s: no _ in <%s>\n", __func__, sentence.c_str());
1070 continue;
1071 }
1072 std::istringstream stream(answer.c_str());
1073 int i_answer; stream >> i_answer;
1074 if (stream.fail() || i_answer < 1 || i_answer > 2) {
1075 LOG_ERR("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
1076 continue;
1077 }
1078 result.emplace_back();
1079 auto& wg = result.back();
        wg.first = sentence.substr(0, where);
        wg.second = sentence.substr(where + 1, sentence.size() - where - 1);
1082 wg.choices[0] = std::move(choice1);
1083 wg.choices[1] = std::move(choice2);
1084 wg.answer = i_answer;
1085 }
1086 return result;
1087}
1088
1089/*
1090 * Evaluates the Winogrande score.
 * Uses a CSV containing task index, sentence, choice 1, choice 2, answer (1 or 2)
1092 * You can get one such dataset from e.g. https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp
1093 * As an example, the 1st row in the above dataset is
1094 *
1095 * 0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
1096 *
1097 */
1098static void winogrande_score(llama_context * ctx, const common_params & params) {
1099 const llama_model * model = llama_get_model(ctx);
1100 const llama_vocab * vocab = llama_model_get_vocab(model);
1101
1102 constexpr int k_min_trailing_ctx = 3;
1103
1104 auto data = load_winogrande_from_csv(prompt: params.prompt);
1105 if (data.empty()) {
1106 LOG_ERR("%s: no tasks\n", __func__);
1107 return;
1108 }
1109
1110 LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, data.size());
1111
1112 if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
1113 LOG_INF("%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
1114 std::mt19937 rng(1);
1115 std::vector<int> aux(data.size());
1116 for (int i = 0; i < int(data.size()); ++i) {
1117 aux[i] = i;
1118 }
1119 float scale = 1/(1.f + (float)rng.max());
1120 std::vector<winogrande_entry> selected;
1121 selected.resize(new_size: params.winogrande_tasks);
1122 for (int i = 0; i < int(params.winogrande_tasks); ++i) {
1123 int j = int(scale*rng()*aux.size());
1124 selected[i] = std::move(data[aux[j]]);
1125 aux[j] = aux.back();
1126 aux.pop_back();
1127 }
1128 data = std::move(selected);
1129 }
1130
1131 LOG_INF("%s : tokenizing selected tasks\n", __func__);
1132
1133 for (auto & task : data) {
1134 task.seq_tokens[0] = common_tokenize(ctx, text: task.first + task.choices[0] + task.second, add_special: true);
1135 task.seq_tokens[1] = common_tokenize(ctx, text: task.first + task.choices[1] + task.second, add_special: true);
1136
1137 task.common_prefix = 0;
1138 for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
1139 if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
1140 break;
1141 }
1142 task.common_prefix++;
1143 }
1144
        // TODO: the last token of each of the sequences doesn't need to be evaluated
1146 task.required_tokens = task.common_prefix +
1147 task.seq_tokens[0].size() - task.common_prefix +
1148 task.seq_tokens[1].size() - task.common_prefix;
1149
1150 task.n_base1 = common_tokenize(ctx, text: task.first + task.choices[0], add_special: true).size();
1151 task.n_base2 = common_tokenize(ctx, text: task.first + task.choices[1], add_special: true).size();
1152 }
1153
1154 LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
1155
1156 const int n_ctx = llama_n_ctx(ctx);
1157 const int n_batch = params.n_batch;
1158
1159 const int n_vocab = llama_vocab_n_tokens(vocab);
1160
1161 const int max_tasks_per_batch = 128;
1162 const int max_seq = std::min(a: 2*max_tasks_per_batch, b: (int) llama_n_seq_max(ctx));
1163
1164 llama_batch batch = llama_batch_init(n_tokens: n_ctx, embd: 0, n_seq_max: 2);
1165
1166 std::vector<float> tok_logits(n_vocab);
1167 // TODO: this could be made smaller; it's currently the worst-case size
1168 std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
1169
1170 std::vector<std::pair<size_t, llama_token>> eval_pairs;
1171 std::vector<float> eval_results;
1172 std::vector<std::thread> workers(std::thread::hardware_concurrency());
1173
1174 int n_correct = 0;
1175 int n_done = 0;
1176
1177 for (size_t i0 = 0; i0 < data.size(); i0++) {
1178 int n_cur = 0;
1179
1180 size_t i1 = i0;
1181 size_t i_logits = 0;
1182
1183 common_batch_clear(batch);
1184
1185 while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
1186 int n_logits = 0;
1187 const int s0 = 2*(i1 - i0);
1188 if (s0 + 2 > max_seq) {
1189 break;
1190 }
1191
1192 for (size_t i = 0; i < data[i1].common_prefix; ++i) {
1193 common_batch_add(batch, id: data[i1].seq_tokens[0][i], pos: i, seq_ids: { s0 + 0, s0 + 1 }, logits: false);
1194 }
1195 batch.logits[batch.n_tokens - 1] = true;
1196 n_logits += 1;
1197
1198 for (int s = 0; s < 2; ++s) {
1199 // TODO: end before the last token, no need to predict past the end of the sequences
1200 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
1201 common_batch_add(batch, id: data[i1].seq_tokens[s][i], pos: i, seq_ids: { s0 + s }, logits: true);
1202 n_logits += 1;
1203 }
1204 }
1205
1206 data[i1].i_logits = i_logits;
1207 i_logits += n_logits;
1208
1209 n_cur += data[i1].required_tokens;
1210 if (++i1 == data.size()) {
1211 break;
1212 }
1213 }
1214
1215 if (i0 == i1) {
1216 LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, data[i0].required_tokens);
1217 return;
1218 }
1219
1220 llama_memory_clear(mem: llama_get_memory(ctx), data: true);
1221
1222 // decode all tasks [i0, i1)
1223 if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
1224 LOG_ERR("%s: llama_decode() failed\n", __func__);
1225 return;
1226 }
1227
1228 eval_pairs.clear();
1229 for (size_t i = i0; i < i1; ++i) {
1230 auto & task = data[i];
1231
1232 const bool skip_choice =
1233 task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1234 task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1235
1236 const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1237 const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1238 size_t li = n_base1 - task.common_prefix;
1239 for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
1240 eval_pairs.emplace_back(args: task.i_logits + li++, args&: task.seq_tokens[0][j+1]);
1241 }
1242 const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1243 const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1244 // FIXME: this uses the wrong first logits when not skipping the choice word
1245 li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
1246 for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
1247 eval_pairs.emplace_back(args: task.i_logits + li++, args&: task.seq_tokens[1][j+1]);
1248 }
1249 }
1250 compute_logprobs(batch_logits: batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
1251
1252 size_t ir = 0;
1253 for (size_t i = i0; i < i1; ++i) {
1254 auto & task = data[i];
1255
1256 const bool skip_choice =
1257 task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1258 task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1259
1260 float score_1st = 0;
1261 const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1262 const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1263 for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
1264 score_1st += eval_results[ir++];
1265 }
1266 score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
1267
1268 float score_2nd = 0;
1269 const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1270 const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1271 for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
1272 score_2nd += eval_results[ir++];
1273 }
1274 score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
1275
1276 int result = score_1st > score_2nd ? 1 : 2;
1277
1278 if (result == task.answer) {
1279 ++n_correct;
1280 }
1281 ++n_done;
1282
1283 // print the accumulated accuracy mean x 100
1284 LOG("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
1285 }
1286
1287 i0 = i1 - 1;
1288 }
1289
1290 LOG("\n");
1291
1292 if (n_done < 100) return;
1293
1294 const float p = 1.f*n_correct/n_done;
    const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));
1296
1297 LOG_INF("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
1298}
1299
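// The multiple-choice prompt is a small binary format rather than plain text. As read here, a string
// is a uint32 length followed by the raw bytes; an answer set is a uint32 count, that many strings and
// then count int labels; the file itself starts with a uint32 task count and a table of byte offsets,
// one per task, so random tasks can be seeked to directly (see multiple_choice_score below).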
1300static bool deserialize_string(std::istream & in, std::string & str) {
1301 uint32_t size;
    if (!in.read((char *)&size, sizeof(size)).fail()) {
        str.resize(size);
        if (!in.read((char *)&str[0], size).fail()) return true;
1305 }
1306 return false;
1307}
1308
1309struct multiple_choice_answers {
1310 std::vector<std::string> answers;
1311 std::vector<int> labels;
1312 bool deserialize(std::istream& in) {
1313 uint32_t n;
        in.read((char *)&n, sizeof(n));
        if (in.fail() || n > 100) return false; // 100 as max. number of answers should be good enough for any practical purpose
        answers.resize(n);
        labels.resize(n);
        for (auto& a : answers) {
            if (!deserialize_string(in, a)) return false;
        }
        in.read((char *)labels.data(), n*sizeof(int));
1322 return !in.fail();
1323 }
1324};
1325
1326struct multiple_choice_task {
1327 std::string question; // the question (or context that needs to be continued)
1328 multiple_choice_answers mc1; // possible answers (continuations) with a single correct answer
1329 multiple_choice_answers mc2; // possible answers (continuations) with multiple correct answers - not handled yet
1330 bool deserialize(std::istream& in) {
1331 if (!deserialize_string(in, str&: question)) return false;
1332 return mc1.deserialize(in) && mc2.deserialize(in);
1333 }
1334
1335 // For evaluation
1336 size_t i_logits; // starting index of logits in the llama_batch
1337 size_t common_prefix; // max number of initial tokens that are the same in all sentences
1338 size_t required_tokens; // needed number of tokens to evaluate all answers
1339 std::vector<std::vector<llama_token>> seq_tokens;
1340 std::vector<float> log_probs;
1341};
1342
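// Tokenizes question + " " + answer for every candidate answer of a task, determines the longest
// common token prefix shared by all candidates, and counts the total number of tokens the task will
// occupy in a batch (the common prefix is counted once). Returns false for malformed tasks.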
1343static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
1344 if (task.question.empty() || task.mc1.answers.empty()) {
1345 if (log_error) {
1346 LOG_ERR("%s: found bad task with empty question and/or answers\n", __func__);
1347 }
1348 return false;
1349 }
1350 task.seq_tokens.reserve(n: task.mc1.answers.size());
1351 for (auto& answer : task.mc1.answers) {
1352 if (answer.empty()) {
1353 if (log_error) {
1354 LOG_ERR("%s: found empty answer\n", __func__);
1355 }
1356 return false;
1357 }
        task.seq_tokens.emplace_back(::common_tokenize(ctx, task.question + " " + answer, true));
1359 }
1360 auto min_len = task.seq_tokens.front().size();
1361 for (auto& seq : task.seq_tokens) {
1362 min_len = std::min(a: min_len, b: seq.size());
1363 }
1364 task.common_prefix = 0;
1365 for (size_t k = 0; k < min_len; ++k) {
1366 auto token = task.seq_tokens[0][k];
1367 bool all_same = true;
1368 for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
1369 if (task.seq_tokens[i][k] != token) {
1370 all_same = false;
1371 break;
1372 }
1373 }
1374 if (!all_same) {
1375 break;
1376 }
1377 ++task.common_prefix;
1378 }
1379 task.required_tokens = task.common_prefix;
1380 for (auto& seq : task.seq_tokens) {
1381 task.required_tokens += seq.size() - task.common_prefix;
1382 }
1383 return true;
1384}
1385
1386//
1387// Calculates score for multiple choice tasks with single correct answer from prompt.
1388// Commonly used LLM evaluation metrics of this type are
1389// * ARC
1390// * HellaSwag
1391// * MMLU
1392// * TruthfulQA
1393//
1394// Validation datasets for these 4 tests can be found at
1395// https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
1396// The data for these datasets was extracted from
1397// git@hf.co:datasets/allenai/ai2_arc
1398// https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
1399// git@hf.co:datasets/Stevross/mmlu
1400// https://huggingface.co/datasets/truthful_qa
1401//
1402static void multiple_choice_score(llama_context * ctx, const common_params & params) {
1403 const llama_model * model = llama_get_model(ctx);
1404 const llama_vocab * vocab = llama_model_get_vocab(model);
1405
1406 std::istringstream strstream(params.prompt);
1407 uint32_t n_task;
1408 strstream.read(s: (char *)&n_task, n: sizeof(n_task));
1409 if (strstream.fail() || n_task == 0) {
1410 LOG_ERR("%s: no tasks\n", __func__);
1411 return;
1412 }
1413 LOG_INF("%s: there are %u tasks in prompt\n", __func__, n_task);
1414 std::vector<uint32_t> task_pos(n_task);
1415 strstream.read(s: (char *)task_pos.data(), n: task_pos.size()*sizeof(uint32_t));
1416 if (strstream.fail()) {
1417 LOG_ERR("%s: failed to read task positions from prompt\n", __func__);
1418 return;
1419 }
1420
1421 std::vector<multiple_choice_task> tasks;
1422 if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
1423 // Use all tasks
1424 tasks.resize(new_size: n_task);
1425 LOG_INF("%s: reading tasks", __func__);
1426 int n_dot = std::max(a: (int) n_task/100, b: 1);
1427 int i = 0;
1428 for (auto& task : tasks) {
1429 ++i;
1430 if (!task.deserialize(in&: strstream)) {
1431 LOG_ERR("%s: failed to read task %d of %u\n", __func__, i, n_task);
1432 return;
1433 }
1434 if (i%n_dot == 0) LOG(".");
1435 }
1436 LOG("done\n");
1437 }
1438 else {
1439 LOG_INF("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
1440 std::mt19937 rng(1);
1441 std::vector<int> aux(n_task);
1442 for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
1443 float scale = 1.f/(1.f + (float)std::mt19937::max());
1444 tasks.resize(new_size: params.multiple_choice_tasks);
1445 for (auto& task : tasks) {
1446 int j = (int)(scale * rng() * aux.size());
1447 int idx = aux[j];
1448 aux[j] = aux.back();
1449 aux.pop_back();
1450 strstream.seekg(task_pos[idx], std::ios::beg);
1451 if (!task.deserialize(in&: strstream)) {
1452 LOG_ERR("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
1453 return;
1454 }
1455 }
1456 n_task = params.multiple_choice_tasks;
1457 }
1458
1459 LOG_INF("%s: preparing task data", __func__);
1460 if (n_task > 500) {
1461 LOG("...");
1462 std::atomic<int> counter(0);
1463 std::atomic<int> n_bad(0);
1464 auto prepare = [&counter, &n_bad, &tasks, ctx] () {
1465 int num_tasks = tasks.size();
1466 int n_bad_local = 0;
1467 while (true) {
1468 int first = counter.fetch_add(K_TOKEN_CHUNK);
1469 if (first >= num_tasks) {
1470 if (n_bad_local > 0) n_bad += n_bad_local;
1471 break;
1472 }
1473 int last = std::min(a: first + K_TOKEN_CHUNK, b: num_tasks);
1474 for (int i = first; i < last; ++i) {
1475 if (!multiple_choice_prepare_one_task(ctx, task&: tasks[i], log_error: false)) ++n_bad_local;
1476 }
1477 }
1478 };
1479 size_t max_thread = std::thread::hardware_concurrency();
1480 max_thread = std::min(a: max_thread, b: (tasks.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK);
1481 std::vector<std::thread> workers(max_thread-1);
1482 for (auto& w : workers) w = std::thread(prepare);
1483 prepare();
1484 for (auto& w : workers) w.join();
1485 LOG("done\n");
1486 int nbad = n_bad;
1487 if (nbad > 0) {
1488 LOG_ERR("%s: found %d malformed tasks\n", __func__, nbad);
1489 return;
1490 }
1491 } else {
1492 int n_dot = std::max(a: (int) n_task/100, b: 1);
1493 int i_task = 0;
1494 for (auto& task : tasks) {
1495 ++i_task;
1496 if (!multiple_choice_prepare_one_task(ctx, task, log_error: true)) {
1497 return;
1498 }
1499 if (i_task%n_dot == 0) {
1500 LOG(".");
1501 }
1502 }
1503 LOG("done\n");
1504 }
1505
1506 LOG_INF("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
1507
1508 LOG("\ntask\tacc_norm\n");
1509
1510 const int n_ctx = llama_n_ctx(ctx);
1511 const int n_batch = params.n_batch;
1512
1513 const int n_vocab = llama_vocab_n_tokens(vocab);
1514
1515 const int max_tasks_per_batch = 32;
1516 const int max_seq = std::min(a: 4*max_tasks_per_batch, b: (int) llama_n_seq_max(ctx));
1517
1518 llama_batch batch = llama_batch_init(n_tokens: n_ctx, embd: 0, n_seq_max: max_seq);
1519
1520 std::vector<float> tok_logits(n_vocab);
1521 std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
1522
1523 std::vector<std::pair<size_t, llama_token>> eval_pairs;
1524 std::vector<float> eval_results;
1525 std::vector<std::thread> workers(std::thread::hardware_concurrency());
1526 std::vector<int> batch_indeces;
1527
1528 int n_done = 0;
1529 int n_correct = 0;
1530 int n_tot_answers = 0;
1531
1532 for (size_t i0 = 0; i0 < tasks.size(); i0++) {
1533 int n_cur = 0;
1534
1535 size_t i1 = i0;
1536 size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
1537
1538 common_batch_clear(batch);
1539
        // batch as many tasks as possible into the available context
1541 // each task has 4 unique sequence ids - one for each ending
1542 // the common prefix is shared among the 4 sequences to save tokens
1543 // we extract logits only from the last common token and from all ending tokens of each sequence
        int s0 = 0;
        while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
            auto& cur_task = tasks[i1];
            int n_logits = 0;

            int num_answers = cur_task.seq_tokens.size();
            if (s0 + num_answers > max_seq) {
                if (s0 == 0) {
                    LOG_ERR("%s : task %zu requires a higher -np|--parallel value (at least %d)\n", __func__, i0, num_answers);
                    return;
                }
                break;
            }

            if (int(batch_indeces.size()) != num_answers) {
                batch_indeces.resize(num_answers);
            }

            for (int s = 0; s < num_answers; ++s) {
                batch_indeces[s] = s0 + s;
            }

            for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
                common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
            }
            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
            n_logits += 1;

            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
                // TODO: don't evaluate the last token of each sequence
                for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                    const bool needs_logits = i < seq_tokens_size - 1;
                    common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                    n_logits += needs_logits;
                }
            }

            s0 += num_answers;

            cur_task.i_logits = i_logits;
            i_logits += n_logits;

            n_cur += cur_task.required_tokens;
            if (++i1 == tasks.size()) {
                break;
            }
        }

        if (i0 == i1) {
            LOG_ERR("%s : task %zu does not fit in the context window (requires %zu tokens)\n", __func__, i0, tasks[i0].required_tokens);
            return;
        }

        llama_memory_clear(llama_get_memory(ctx), true);

        // decode all tasks [i0, i1)
        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return;
        }

        // Compute log-probs in parallel
        // First we collect all tasks
        eval_pairs.clear();
        for (size_t i = i0; i < i1; ++i) {
            auto& cur_task = tasks[i];
            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
                    eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
                }
            }
        }
        // Then we do the actual calculation
        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);

        size_t ir = 0;

        // compute the logprobs for each ending of the decoded tasks
        for (size_t i = i0; i < i1; ++i) {
            auto & cur_task = tasks[i];
            //LOG("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
            //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
            //    if (cur_task.mc1.labels[j] == 1) {
            //        LOG("%d", j+1);
            //    }
            //}
            //LOG("\n common_prefix: %zu\n", cur_task.common_prefix);

            // get the logits of the last token of the common prefix
            std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float));

            const auto first_probs = softmax(tok_logits);

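            // score each answer by its length-normalized log-probability: the log-prob of the first
            // answer token comes from the shared-prefix logits, the rest from eval_results;
            // dividing by the token count gives the acc_norm-style average compared below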
            cur_task.log_probs.resize(cur_task.seq_tokens.size());
            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                size_t count = 1;
                float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
                for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
                    //LOG(" %zu %g\n", ir, eval_results[ir]);
                    ++count;
                    log_prob += eval_results[ir++];
                }
                cur_task.log_probs[s] = log_prob / count;
                //LOG(" Final: %g\n", log_prob / count);
                //LOG(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
            }

            // Find the ending with maximum logprob
            size_t logprob_max_idx = 0;
            float logprob_max_val = cur_task.log_probs[0];
            for (size_t s = 1; s < cur_task.log_probs.size(); s++) {
                if (cur_task.log_probs[s] > logprob_max_val) {
                    logprob_max_val = cur_task.log_probs[s];
                    logprob_max_idx = s;
                }
            }

            n_tot_answers += cur_task.log_probs.size();
            if (cur_task.mc1.labels[logprob_max_idx] == 1) {
                ++n_correct;
            }
            ++n_done;

            // Print the accumulated accuracy mean x 100
            LOG("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
        }

        i0 = i1 - 1;
    }

    llama_batch_free(batch);

    if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;

    float p = 1.f*n_correct/n_done;
    float sigma = sqrt(p*(1-p)/(n_done-1));
    LOG("\n");
    LOG_INF("Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
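    // baseline accuracy for random guessing: on average one correct answer out of
    // n_tot_answers/n_done possible answers per task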
    p = 1.f*n_done/n_tot_answers;
    sigma = sqrt(p*(1-p)/(n_done-1));
    LOG_INF("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);

    LOG_INF("\n");
}

static void kl_divergence(llama_context * ctx, const common_params & params) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    if (params.logits_file.empty()) {
        LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
        return;
    }
    std::ifstream in(params.logits_file.c_str(), std::ios::binary);
    if (!in) {
        LOG_ERR("%s: failed to open %s\n", __func__, params.logits_file.c_str());
        return;
    }
    {
        char check[9]; check[8] = 0;
        in.read(check, 8);
        if (in.fail() || strncmp("_logits_", check, 8) != 0) {
            LOG_ERR("%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
            return;
        }
    }
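
    // file layout (as read below): 8-byte magic "_logits_", then n_ctx (uint32), n_vocab and n_chunk
    // (int), the evaluation tokens, and one block of quantized log-probs per chunk (stride nv, see below);
    // such a file is presumably produced by an earlier run of the base model with --kl-divergence-base,
    // e.g. (assumed invocation):
    //   llama-perplexity -m base.gguf -f wiki.test.raw --kl-divergence-base base_logits.bin
    //   llama-perplexity -m quant.gguf --kl-divergence --kl-divergence-base base_logits.bin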

    uint32_t n_ctx;
    in.read((char *)&n_ctx, sizeof(n_ctx));
    if (n_ctx > llama_n_ctx(ctx)) {
        LOG_ERR("%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n",
                __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
    }

    int n_vocab;
    int n_chunk;
    in.read((char *)&n_vocab, sizeof(n_vocab));
    in.read((char *)&n_chunk, sizeof(n_chunk));
    if (in.fail()) {
        LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
        return;
    }
    if (n_vocab != llama_vocab_n_tokens(vocab)) {
        LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab));
    }

    std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
    if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
        LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
        return;
    }

    const int n_batch = params.n_batch;
    const int num_batches = (n_ctx + n_batch - 1)/n_batch;
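    // per-token stride of the stored log-probs: n_vocab uint16 values rounded up to an even count,
    // plus 4 uint16 (two floats) holding the quantization scale and the minimum log-prob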
    const int nv = 2*((n_vocab + 1)/2) + 4;
    const bool add_bos = llama_vocab_get_add_bos(vocab);
    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));

    std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
    std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
    std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
    std::vector<float> logits;
    if (num_batches > 1) {
        logits.reserve(size_t(n_ctx) * n_vocab);
    }

    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

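    // helpers for the running statistics: mean_and_uncertainty returns the sample mean and its
    // standard error; covariance returns the covariance of the two means (both require enough samples)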
    auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
        if (count < 1) {
            return std::make_pair(0., 0.);
        }
        double f = sum/count;
        double df = sum2/count - f*f;
        df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
        return std::make_pair(f, df);
    };
    auto covariance = [] (double suma, double sumb, double sumab, size_t count) {
        if (count < 10) {
            return 0.0;
        }
        double var = sumab/count - (suma/count)*(sumb/count);
        var /= count - 1;
        return var;
    };

    kl_divergence_result kld;
    auto kld_ptr = kld_values.data();
    auto p_diff_ptr = p_diff_values.data();

    for (int i = 0; i < n_chunk; ++i) {
        const int start = i * n_ctx;
        const int end = start + n_ctx;

        const auto t_start = std::chrono::high_resolution_clock::now();

        if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
            LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
            return;
        }

        // clear the KV cache
        llama_memory_clear(llama_get_memory(ctx), true);

        llama_batch batch = llama_batch_init(n_batch, 0, 1);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size = std::min(end - batch_start, n_batch);

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (add_bos && j == 0) {
                tokens[batch_start] = llama_vocab_bos(vocab);
            }

            common_batch_clear(batch);
            for (int i = 0; i < batch_size; i++) {
                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
            }

            if (llama_decode(ctx, batch)) {
                LOG_ERR("%s : failed to eval\n", __func__);
                llama_batch_free(batch);
                return;
            }

            // restore the original token in case it was set to BOS
            tokens[batch_start] = token_org;

            if (num_batches > 1) {
                const auto * batch_logits = llama_get_logits(ctx);
                logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
            }
        }

        llama_batch_free(batch);

        const auto t_end = std::chrono::high_resolution_clock::now();

        if (i == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
            if (total_seconds >= 60*60) {
                LOG("%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            LOG("%.2f minutes\n", total_seconds / 60.0);
        }
        LOG("\n");
        LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");

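        // only the second half of each chunk is scored - the first half serves as context,
        // which matches the window for which the base run stored its log-probs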
        const int first = n_ctx/2;
        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
        process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                       workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
        p_diff_ptr += n_ctx - 1 - first;
        kld_ptr += n_ctx - 1 - first;

        LOG("%4d", i+1);

        auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
        const double ppl_val = exp(log_ppl.first);
        const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
        LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);

        auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
        const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
        const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
        const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
        LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);

        auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
        LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);

        auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
        const double p_diff_rms_val = sqrt(p_diff_mse.first);
        const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
        LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);

        double p_top_val = 1.*kld.n_same_top/kld.count;
        double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
        LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);

        LOG("\n");

        logits.clear();
    }
    LOG("\n");

    if (kld.count < 100) return; // we do not wish to do statistics on so few values

    std::sort(kld_values.begin(), kld_values.end());
    std::sort(p_diff_values.begin(), p_diff_values.end());

    LOG("====== Perplexity statistics ======\n");

    auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
    const double ppl_val = exp(log_ppl.first);
    const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
    LOG("Mean PPL(Q)                   : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);

    auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
    const double ppl_base_val = exp(log_ppl_base.first);
    const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 )
    LOG("Mean PPL(base)                : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);

    const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
    // LOG("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
    const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
    LOG("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);

    const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
    const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
    LOG("Mean ln(PPL(Q)/PPL(base))     : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);

    const double ppl_ratio_val = exp(log_ppl_ratio_val);
    const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
    LOG("Mean PPL(Q)/PPL(base)         : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);

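    // first-order error propagation: Cov(PPL(Q), PPL(base)) ≈ PPL(Q)*PPL(base)*Cov(ln PPL(Q), ln PPL(base)),
    // which feeds into the uncertainty of the perplexity difference below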
    const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
    const double ppl_diff_val = ppl_val - ppl_base_val;
    const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
    LOG("Mean PPL(Q)-PPL(base)         : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);

    LOG("\n");

    LOG("====== KL divergence statistics ======\n");
    auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
    LOG("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
    auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
                                               : kld_values[kld_values.size()/2];

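    // percentile of the sorted values by linear interpolation between the two nearest order statistics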
    auto percentile = [] (std::vector<float> values, float fraction) {
        if (fraction <= 0) return values.front();
        if (fraction >= 1) return values.back();
        float p = fraction*(values.size() - 1);
        size_t ip = size_t(p); p -= ip;
        return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
    };

    LOG("Maximum KLD: %10.6f\n", kld_values.back());
    LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
    LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
    LOG("95.0%% KLD: %10.6f\n", percentile(kld_values, 0.950f));
    LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f));
    LOG("Median KLD: %10.6f\n", kld_median);
    LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
    LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
    LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
    LOG(" 0.1%% KLD: %10.6f\n", percentile(kld_values, 0.001f));
    LOG("Minimum KLD: %10.6f\n", kld_values.front());

    LOG("\n");

    LOG("====== Token probability statistics ======\n");

    auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
    LOG("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second);

    auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1])
                                                     : p_diff_values[p_diff_values.size()/2];

    LOG("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back());
    LOG("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
    LOG("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
    LOG("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
    LOG("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
    LOG("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
    LOG("Median Δp: %6.3lf%%\n", 100.0*p_diff_median);
    LOG("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
    LOG("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
    LOG(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
    LOG(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
    LOG(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
    LOG("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front());

    auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
    // LOG("MSE Δp: %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second);

    const double p_diff_rms_val = sqrt(p_diff_mse.first);
    const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
    LOG("RMS Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);

    const double same_top_p = 1.0*kld.n_same_top/kld.count;
    LOG("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
}

int main(int argc, char ** argv) {
    common_params params;

    params.n_ctx = 512;
    params.escape = false;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
        return 1;
    }

    common_init();

    const int32_t n_ctx = params.n_ctx;

    if (n_ctx <= 0) {
        LOG_ERR("%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
        return 1;
    }

    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;

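    // for plain perplexity, pack several n_ctx-sized sequences into one physical batch:
    // the KV cache is sized to n_seq*n_ctx so the chunks can be decoded in parallel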
    if (ppl) {
        const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
        const int32_t n_kv = n_seq * n_ctx;

        params.n_parallel = n_seq;
        params.n_ctx = n_kv;

        params.n_batch = std::min(params.n_batch, n_kv);
    } else {
        params.n_batch = std::min(params.n_batch, params.n_ctx);
        if (params.kl_divergence) {
            params.n_parallel = 1;
        } else {
            // ensure there's at least enough seq_ids for HellaSwag
            params.n_parallel = std::max(4, params.n_parallel);
        }
    }

    if (params.ppl_stride > 0) {
        LOG_INF("Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
                params.n_ctx, params.n_ctx + params.ppl_stride/2);
        params.n_ctx += params.ppl_stride/2;
    }

    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model and apply lora adapter, if any
    common_init_result llama_init = common_init_from_params(params);

    llama_model * model = llama_init.model.get();
    llama_context * ctx = llama_init.context.get();

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    const int n_ctx_train = llama_model_n_ctx_train(model);

    if (params.n_ctx > n_ctx_train) {
        LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, params.n_ctx);
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

    struct results_perplexity results;
    if (params.hellaswag) {
        hellaswag_score(ctx, params);
    } else if (params.winogrande) {
        winogrande_score(ctx, params);
    } else if (params.multiple_choice) {
        multiple_choice_score(ctx, params);
    } else if (params.kl_divergence) {
        kl_divergence(ctx, params);
    } else {
        results = perplexity(ctx, params, n_ctx);
    }

    LOG("\n");
    llama_perf_context_print(ctx);
    llama_memory_breakdown_print(ctx);

    llama_backend_free();

    return 0;
}