#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <chrono>
#include <algorithm>
#include <array>
#include <atomic>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <mutex>
#include <random>
#include <sstream>
#include <thread>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

struct results_perplexity {
    std::vector<llama_token> tokens;
    double ppl_value;
    std::vector<float> logits;
    std::vector<float> probs;
};

struct results_log_softmax {
    double log_softmax;
    float  logit;
    float  prob;
};

static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    float max_logit = logits[0];
    for (float v : logits) {
        max_logit = std::max(max_logit, v);
    }
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        // Subtract the maximum logit value from the current logit value for numerical stability
        const float logit = logits[i] - max_logit;
        const float exp_logit = expf(logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }
    for (size_t i = 0; i < probs.size(); i++) {
        probs[i] /= sum_exp;
    }
    return probs;
}

static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        max_logit = std::max(max_logit, logits[i]);
    }
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
}

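// Fast round-to-nearest for floats: adding 1.5*2^23 (12582912.f) forces the sum into a
// binade where the low 23 mantissa bits hold the rounded integer biased by 0x00400000;
// masking and subtracting the bias recovers it. Only valid for |fval| < 2^22, hence the
// commented-out assert.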
static inline int nearest_int(float fval) {
    //assert(fval <= 4194303.f);
    float val = fval + 12582912.f;
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}

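// Variant of log_softmax that also stores the entire log-prob vector in a compact 16-bit
// format for later comparison against another model: a 4-uint16 header holds two floats
// (quantization scale and minimum log-prob), followed by one quantized value per vocab
// entry, so that scale*q[i] + min_log_prob reconstructs the log-probability of token i.
// The dynamic range is clamped to 16 nats below the maximum logit. Returns the NLL of tok.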
static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
    float max_logit = logits[0];
    float min_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        max_logit = std::max(max_logit, logits[i]);
        min_logit = std::min(min_logit, logits[i]);
    }
    min_logit = std::max(min_logit, max_logit - 16);
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    const float log_sum_exp = log(sum_exp);
    const float min_log_prob = min_logit - max_logit - log_sum_exp;
    const float scale = (max_logit - min_logit)/65535.f;
    float * d = (float *)log_prob;
    d[0] = scale;
    d[1] = min_log_prob;
    log_prob += 4;
    if (scale) {
        const float inv_scale = 1/scale;
        for (int i = 0; i < n_vocab; ++i) {
            log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
        }
    } else {
        std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
    }
    return max_logit + log_sum_exp - logits[tok];
}

static void process_logits(
    int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
    double & nll, double & nll2, float * logit_history, float * prob_history
) {
    std::mutex mutex;
    int counter = 0;
    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
        double local_nll  = 0;
        double local_nll2 = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            int i = counter++;
            if (i >= n_token) {
                nll += local_nll; nll2 += local_nll2;
                break;
            }
            lock.unlock();
            const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
            const double v = -results.log_softmax;
            local_nll += v;
            local_nll2 += v*v;

            logit_history[i] = results.logit;
            prob_history[i]  = results.prob;
        }
    };
    for (auto & w : workers) {
        w = std::thread(compute);
    }
    compute();
    for (auto & w : workers) {
        w.join();
    }
}

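// Same as above, but additionally streams the quantized log-probs of every evaluated token
// to `out`. Each per-token record is nv uint16 values: the 4-slot header plus n_vocab
// quantized log-probs, with n_vocab rounded up to an even count so records stay 4-byte aligned.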
static void process_logits(std::ostream & out, int n_vocab, const float * logits, const int * tokens, int n_token,
        std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
    std::mutex mutex;
    const int nv = 2*((n_vocab + 1)/2) + 4;
    int counter = 0;
    auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
        double local_nll  = 0;
        double local_nll2 = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            int i = counter++;
            if (i >= n_token) {
                nll += local_nll; nll2 += local_nll2;
                break;
            }
            lock.unlock();
            const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
            local_nll += v;
            local_nll2 += v*v;
        }
    };
    for (auto & w : workers) {
        w = std::thread(compute);
    }
    compute();
    for (auto & w : workers) {
        w.join();
    }
    out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
}

struct kl_divergence_result {
    double sum_nll          = 0.0;
    double sum_nll2         = 0.0;
    double sum_nll_base     = 0.0;
    double sum_nll_base2    = 0.0;
    double sum_nll_nll_base = 0.0;
    double sum_kld          = 0.0;
    double sum_kld2         = 0.0;
    double sum_p_diff       = 0.0;
    double sum_p_diff2      = 0.0;
    double sum_p_diff4      = 0.0;
    float  max_p_diff       = 0.0f;
    size_t n_same_top       = 0;
    size_t count            = 0;
};

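// For one token position, accumulates the statistics behind the KL-divergence report:
// NLL under the current model, NLL under the base model (reconstructed from its quantized
// log-probs), their second moments and cross moment, KL(base || current) summed over the
// vocabulary, top-token agreement, and the difference in probability assigned to the
// correct token. Returns {KL divergence, probability difference}.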
static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
    float max_logit = logits[0];
    int imax = 0;
    for (int i = 1; i < n_vocab; ++i) {
        if (logits[i] > max_logit) {
            max_logit = logits[i];
            imax = i;
        }
    }
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    const float log_sum_exp = log(sum_exp);
    const float * d = (const float *)base_log_prob;
    const float scale = d[0];
    const float min_log_prob = d[1];
    base_log_prob += 4;

    const float nll = max_logit + log_sum_exp - logits[tok];
    kld.sum_nll  += nll;
    kld.sum_nll2 += nll*nll;

    const float nll_base = -(scale*base_log_prob[tok] + min_log_prob);
    kld.sum_nll_base  += nll_base;
    kld.sum_nll_base2 += nll_base*nll_base;

    kld.sum_nll_nll_base += nll*nll_base;

    max_logit += log_sum_exp;
    double sum = 0;
    int imax_base = -1;
    float p_log_base_max = 0;
    for (int i = 0; i < n_vocab; ++i) {
        const float p_log_base = scale*base_log_prob[i] + min_log_prob;
        if (i == 0 || p_log_base > p_log_base_max) {
            p_log_base_max = p_log_base;
            imax_base = i;
        }
        if (p_log_base > -16.f) {
            const float p_base = expf(p_log_base);
            sum += p_base * (p_log_base - logits[i] + max_logit);
        }
    }
    kld.sum_kld  += sum;
    kld.sum_kld2 += sum*sum;
    ++kld.count;
    if (imax == imax_base) {
        ++kld.n_same_top;
    }

    const float p_base = expf(-nll_base);
    const float p = expf(-nll);
    const float p_diff = p - p_base;
    kld.sum_p_diff += p_diff;
    const double p_diff2 = p_diff*p_diff;
    kld.sum_p_diff2 += p_diff2;
    kld.sum_p_diff4 += p_diff2*p_diff2;
    kld.max_p_diff = std::max(kld.max_p_diff, std::fabs(p_diff));

    return std::make_pair(sum, p_diff);
}

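// multi-threaded driver for the KL-divergence statistics: workers pull token indices from a
// shared counter and merge their local accumulators into `kld` once all tokens are processed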
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
        float * kld_values, float * p_diff_values) {
    std::mutex mutex;
    const int nv = 2*((n_vocab + 1)/2) + 4;
    int counter = 0;
    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values, p_diff_values] () {
        kl_divergence_result local_kld;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            int i = counter++;
            if (i >= n_token) {
                kld.sum_nll          += local_kld.sum_nll;
                kld.sum_nll2         += local_kld.sum_nll2;
                kld.sum_nll_base     += local_kld.sum_nll_base;
                kld.sum_nll_base2    += local_kld.sum_nll_base2;
                kld.sum_nll_nll_base += local_kld.sum_nll_nll_base;
                kld.sum_kld          += local_kld.sum_kld;
                kld.sum_kld2         += local_kld.sum_kld2;
                kld.sum_p_diff       += local_kld.sum_p_diff;
                kld.sum_p_diff2      += local_kld.sum_p_diff2;
                kld.sum_p_diff4      += local_kld.sum_p_diff4;
                kld.n_same_top       += local_kld.n_same_top;
                kld.max_p_diff        = std::max(kld.max_p_diff, local_kld.max_p_diff);
                kld.count            += local_kld.count;
                break;
            }
            lock.unlock();
            std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
            kld_values[i]    = (float)v.first;
            p_diff_values[i] = v.second;
        }
    };
    for (auto & w : workers) {
        w = std::thread(compute);
    }
    compute();
    for (auto & w : workers) {
        w.join();
    }
}

static results_perplexity perplexity_v2(llama_context * ctx, const common_params & params) {
    // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const bool add_bos = llama_vocab_get_add_bos(vocab);
    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));

    LOG_INF("%s: tokenizing the input ..\n", __func__);

    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

    const int n_ctx = llama_n_ctx(ctx);

    if (int(tokens.size()) < 2*n_ctx) {
        LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n", __func__, 2*n_ctx,
                n_ctx);
        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n", __func__, tokens.size());
        return {std::move(tokens), 0., {}, {}};
    }

    std::vector<float> logit_history;
    std::vector<float> prob_history;

    logit_history.resize(tokens.size());
    prob_history.resize(tokens.size());

    if (params.ppl_stride <= 0) {
        LOG_ERR("%s: stride is %d but must be greater than zero!\n", __func__, params.ppl_stride);
        return {tokens, -1, logit_history, prob_history};
    }

    const int calc_chunk = n_ctx;

    LOG_INF("%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);

    if (int(tokens.size()) <= calc_chunk) {
        LOG_ERR("%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n", __func__,
                tokens.size(), n_ctx, params.ppl_stride);
        return {tokens, -1, logit_history, prob_history};
    }

    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
    const int n_batch = params.n_batch;

    const int n_vocab = llama_vocab_n_tokens(vocab);

    int count = 0;
    double nll = 0.0;

    LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

    for (int i = 0; i < n_chunk; ++i) {
        const int start =     i * params.ppl_stride;
        const int end   = start + calc_chunk;

        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
        //LOG_DBG("%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);

        std::vector<float> logits;

        const auto t_start = std::chrono::high_resolution_clock::now();

        // clear the KV cache
        llama_memory_clear(llama_get_memory(ctx), true);

        llama_batch batch = llama_batch_init(n_batch, 0, 1);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            common_batch_clear(batch);
            for (int i = 0; i < batch_size; i++) {
                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
            }

            //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n", j, batch_start, batch_size, j * n_batch);
            if (llama_decode(ctx, batch)) {
                //LOG_ERR("%s : failed to eval\n", __func__);
                llama_batch_free(batch);
                return {tokens, -1, logit_history, prob_history};
            }

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (add_bos && j == 0) {
                tokens[batch_start] = llama_vocab_bos(vocab);
            }

            const auto * batch_logits = llama_get_logits(ctx);
            logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);

            if (j == 0) {
                tokens[batch_start] = token_org;
            }
        }

        llama_batch_free(batch);

        const auto t_end = std::chrono::high_resolution_clock::now();

        if (i == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
            if (total_seconds >= 60*60) {
                LOG("%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            LOG("%.2f minutes\n", total_seconds / 60.0);
        }

        //LOG_DBG("%s: using tokens %d...%d\n", __func__, params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
        for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
            // Calculate probability of next token, given the previous ones.
            const std::vector<float> tok_logits(
                logits.begin() + size_t(j + 0) * n_vocab,
                logits.begin() + size_t(j + 1) * n_vocab);

            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
            logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
            prob_history[start + j + 1]  = prob;

            nll += -std::log(prob);
            ++count;
        }
        // perplexity is e^(average negative log-likelihood)
        if (params.ppl_output_type == 0) {
            LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
        } else {
            LOG("%8d  %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
        }
    }
    LOG("\n");

    return {tokens, std::exp(nll / count), logit_history, prob_history};
}

static results_perplexity perplexity(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
    if (params.ppl_stride > 0) {
        return perplexity_v2(ctx, params);
    }

    // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
    // Run `./llama-perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const bool add_bos = llama_vocab_get_add_bos(vocab);
    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));

    std::ofstream logits_stream;
    if (!params.logits_file.empty()) {
        logits_stream.open(params.logits_file.c_str(), std::ios::binary);
        if (!logits_stream.is_open()) {
            LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
            return {};
        }
        LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
        logits_stream.write("_logits_", 8);
        logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
    }

    auto tim1 = std::chrono::high_resolution_clock::now();
    LOG_INF("%s: tokenizing the input ..\n", __func__);

    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

    auto tim2 = std::chrono::high_resolution_clock::now();
    LOG_INF("%s: tokenization took %g ms\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

    if (int(tokens.size()) < 2*n_ctx) {
        LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n", __func__, 2*n_ctx,
                n_ctx);
        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n", __func__, tokens.size());
        return {std::move(tokens), 0., {}, {}};
    }

    std::vector<float> logit_history;
    logit_history.resize(tokens.size());

    std::vector<float> prob_history;
    prob_history.resize(tokens.size());

    const int n_chunk_max = tokens.size() / n_ctx;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
    const int n_batch = params.n_batch;

    const int n_vocab = llama_vocab_n_tokens(vocab);

    int count = 0;
    double nll = 0.0;
    double nll2 = 0.0;

    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
    const int n_seq = std::max(1, n_batch / n_ctx);

    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);

    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);

    std::vector<float> logits;
    if (num_batches > 1) {
        logits.reserve(size_t(n_ctx) * n_vocab);
    }

    LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);

    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

    std::vector<uint16_t> log_probs;
    if (!params.logits_file.empty()) {
        logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
        logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
        logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
        const int nv = 2*((n_vocab + 1)/2) + 4;
        log_probs.resize(n_ctx * nv);
    }

    // We get the logits for all the tokens in the context window (params.n_ctx)
    // from llama_decode below. Now, based on https://huggingface.co/docs/transformers/perplexity,
    // calculate the perplexity over the last half of the window (so the model always has
    // some context to predict the token).
    //
    // We rely on the fact that attention in the forward pass only looks at previous
    // tokens here, so the logits returned for each token are an accurate representation
    // of what the model would have predicted at that point.
    //
    // Example, we have a context window of 512, we will compute perplexity for each of the
    // last 256 tokens. Then, we split the input up into context window size chunks to
    // process the entire prompt.
    const int first = n_ctx/2;

    for (int i = 0; i < n_chunk; i += n_seq) {
        const int start =     i * n_ctx;
        const int end   = start + n_ctx;

        const int n_seq_batch = std::min(n_seq, n_chunk - i);

        const auto t_start = std::chrono::high_resolution_clock::now();

        // clear the KV cache
        llama_memory_clear(llama_get_memory(ctx), true);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            int n_outputs = 0;

            batch.n_tokens = 0;
            for (int seq = 0; seq < n_seq_batch; seq++) {
                int seq_start = batch_start + seq*n_ctx;

                // save original token and restore it after decode
                const auto token_org = tokens[seq_start];

                // add BOS token for the first batch of each chunk
                if (add_bos && j == 0) {
                    tokens[seq_start] = llama_vocab_bos(vocab);
                }

                for (int k = 0; k < batch_size; ++k) {
                    const int idx = seq*n_ctx + k;
                    batch.token   [idx]    = tokens[seq_start + k];
                    batch.pos     [idx]    = j*n_batch + k;
                    batch.n_seq_id[idx]    = 1;
                    batch.seq_id  [idx][0] = seq;
                    batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;

                    n_outputs += batch.logits[idx] != 0;
                }
                batch.n_tokens += batch_size;

                // restore the original token in case it was set to BOS
                tokens[seq_start] = token_org;
            }

            if (llama_decode(ctx, batch)) {
                LOG_INF("%s : failed to decode\n", __func__);
                return {tokens, -1, logit_history, prob_history};
            }

            if (num_batches > 1 && n_outputs > 0) {
                const auto * batch_logits = llama_get_logits(ctx);
                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
            }
        }

        if (i == 0) {
            llama_synchronize(ctx);
            const auto t_end = std::chrono::high_resolution_clock::now();
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total*n_chunk/n_seq);
            if (total_seconds >= 60*60) {
                LOG("%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            LOG("%.2f minutes\n", total_seconds / 60.0);
        }

        for (int seq = 0; seq < n_seq_batch; seq++) {
            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);

            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
            if (!params.logits_file.empty()) {
                process_logits(logits_stream, n_vocab, all_logits,
                        tokens_data, n_ctx - 1 - first,
                        workers, log_probs, nll, nll2);
            } else {
                process_logits(n_vocab, all_logits,
                        tokens_data, n_ctx - 1 - first,
                        workers, nll, nll2,
                        logit_history.data() + start + seq*n_ctx + first,
                        prob_history.data()  + start + seq*n_ctx + first);
            }
            count += n_ctx - first - 1;

            // perplexity is e^(average negative log-likelihood)
            if (params.ppl_output_type == 0) {
                LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
            } else {
                double av = nll/count;
                double av2 = nll2/count - av*av;
                if (av2 > 0) {
                    av2 = sqrt(av2/(count-1));
                }
                LOG("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
            }
        }

        logits.clear();
    }
    LOG("\n");

    nll2 /= count;
    nll /= count;
    const double ppl = exp(nll);
    nll2 -= nll * nll;
    if (nll2 > 0) {
        nll2 = sqrt(nll2/(count-1));
        LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
    } else {
        LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
    }

    llama_batch_free(batch);

    return {tokens, ppl, logit_history, prob_history};
}

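// decodes a logical batch that may exceed n_batch by splitting it into views of at most
// n_batch tokens; the logits of all output-flagged positions are packed contiguously into
// batch_logits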
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
    int prev_outputs = 0;
    for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
        const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);

        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        const int ret = llama_decode(ctx, batch_view);
        if (ret != 0) {
            LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
            return false;
        }

        int n_outputs = 0;
        for (int i = 0; i < n_tokens; ++i) {
            n_outputs += batch_view.logits[i] != 0;
        }

        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));

        prev_outputs += n_outputs;
    }

    return true;
}

#define K_TOKEN_CHUNK 4

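// for each (logit row, expected token) pair in eval_pairs, computes the log-softmax
// probability of the expected token; work is handed out to the workers in chunks of
// K_TOKEN_CHUNK pairs claimed through an atomic counter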
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread> & workers,
        const std::vector<std::pair<size_t, llama_token>> & eval_pairs, std::vector<float> & eval_results) {
    if (eval_results.size() != eval_pairs.size()) {
        eval_results.resize(eval_pairs.size());
    }
    if (eval_pairs.empty()) {
        return;
    }

    size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());

    std::atomic<int> counter(0);
    auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
        float local_logprobs[K_TOKEN_CHUNK];
        while (true) {
            const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
            if (first >= eval_results.size()) {
                break;
            }
            const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
            for (size_t i = first; i < last; ++i) {
                const auto * logits = batch_logits + eval_pairs[i].first * n_vocab;
                float max_logit = logits[0];
                for (int j = 1; j < n_vocab; ++j) {
                    max_logit = std::max(max_logit, logits[j]);
                }
                float sum_p = 0.f;
                for (int j = 0; j < n_vocab; ++j) {
                    sum_p += expf(logits[j] - max_logit);
                }
                local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
            }
            std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
        }
    };

    for (size_t it = 0; it < max_threads; ++it) {
        workers[it] = std::thread(compute);
    }
    for (size_t it = 0; it < max_threads; ++it) {
        workers[it].join();
    }
}

static void hellaswag_score(llama_context * ctx, const common_params & params) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // Calculates hellaswag score (acc_norm) from prompt
    //
    // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
    // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
    //
    // All 10042 tasks should be extracted to keep the results standardized like other implementations.
    //
    // Datafile layout:
    // ['??'] denotes json fields
    // 6 lines per task:
    // ['activity_label'] + ": " + ['ctx'] - The first part of the query, the context
    // ['label']                           - The index of the best common sense ending aka gold ending
    // ['endings'][0]                      - Endings added to the first part of the query
    // ['endings'][1]
    // ['endings'][2]
    // ['endings'][3]

    std::vector<std::string> prompt_lines;
    std::istringstream strstream(params.prompt);
    std::string line;

    while (std::getline(strstream, line, '\n')) {
        prompt_lines.push_back(line);
    }

    if (prompt_lines.size() % 6 != 0) {
        LOG_ERR("%s : number of lines in prompt not a multiple of 6.\n", __func__);
        return;
    }

    size_t hs_task_count = prompt_lines.size()/6;
    LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

    const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM;
    LOG_INF("================================= is_spm = %d\n", is_spm);

    // The tasks should be randomized so the score stabilizes quickly.
    bool randomize_tasks = true;

    // Number of tasks to use when computing the score
    if (params.hellaswag_tasks < hs_task_count) {
        hs_task_count = params.hellaswag_tasks;
    }

    // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
    std::mt19937 rng(1);

    // Dataholder for hellaswag tasks
    struct hs_data_t {
        std::string context;
        size_t gold_ending_idx;
        std::string ending[4];
        size_t ending_logprob_count[4];
        double ending_logprob[4];

        size_t i_logits;        // starting index of logits in the llama_batch
        size_t common_prefix;   // max number of initial tokens that are the same in all sentences
        size_t required_tokens; // needed number of tokens to evaluate all 4 endings
        std::vector<llama_token> seq_tokens[4];
    };

    LOG_INF("%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks ? "randomized" : "the first"));

    // Select and read data from prompt lines
    std::vector<hs_data_t> hs_data(hs_task_count);
    for (size_t i = 0; i < hs_task_count; i++) {
        size_t idx = i;

        auto & hs_cur = hs_data[i];

        // Select a random example of those left in the prompt
        if (randomize_tasks) {
            std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1);
            idx = dist(rng);
        }

        hs_cur.context = prompt_lines[idx*6];
        hs_cur.gold_ending_idx = std::stoi(prompt_lines[idx*6+1]);
        for (size_t j = 0; j < 4; j++) {
            hs_cur.ending[j] = prompt_lines[idx*6+2+j];
            hs_cur.seq_tokens[j] = common_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
        }

        // determine the common prefix of the endings
        hs_cur.common_prefix = 0;
        for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
            if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
                break;
            }
            hs_cur.common_prefix++;
        }
        hs_cur.required_tokens = hs_cur.common_prefix +
            hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
            hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
            hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
            hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;

        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size());

        // Delete the selected random example from the prompt
        if (randomize_tasks) {
            prompt_lines.erase(std::next(prompt_lines.begin(), idx*6), std::next(prompt_lines.begin(), idx*6+6));
        }
    }

    LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);

    LOG("\ntask\tacc_norm\t95%% confidence interval\n");

    double acc = 0.0;

    const int n_ctx   = llama_n_ctx(ctx);
    const int n_batch = params.n_batch;

    const int n_vocab = llama_vocab_n_tokens(vocab);

    const int max_tasks_per_batch = 32;
    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

    llama_batch batch = llama_batch_init(n_ctx, 0, 4);

    std::vector<float> tok_logits(n_vocab);
    // TODO: this could be made smaller; it's currently the worst-case size
    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

    std::vector<std::pair<size_t, llama_token>> eval_pairs;
    std::vector<float> eval_results;
    std::vector<std::thread> workers(std::thread::hardware_concurrency());

    for (size_t i0 = 0; i0 < hs_task_count; i0++) {
        int n_cur = 0;

        size_t i1 = i0;
        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

        common_batch_clear(batch);

        // batch as many tasks as possible into the available context
        // each task has 4 unique sequence ids - one for each ending
        // the common prefix is shared among the 4 sequences to save tokens
        // we extract logits only from the last common token and from all ending tokens of each sequence
        while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
            auto & hs_cur = hs_data[i1];
            int n_logits = 0;

            const int s0 = 4*(i1 - i0);
            if (s0 + 4 > max_seq) {
                break;
            }

            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
                common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
            }
            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
            n_logits += 1;

            for (int s = 0; s < 4; ++s) {
                const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
                // TODO: don't evaluate the last token of each sequence
                for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                    const bool needs_logits = i < seq_tokens_size - 1;
                    common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                    n_logits += needs_logits;
                }
            }

            hs_cur.i_logits = i_logits;
            i_logits += n_logits;

            n_cur += hs_data[i1].required_tokens;
            if (++i1 == hs_task_count) {
                break;
            }
        }

        if (i0 == i1) {
            LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, hs_data[i0].required_tokens);
            return;
        }

        llama_memory_clear(llama_get_memory(ctx), true);

        // decode all tasks [i0, i1)
        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return;
        }

        // Compute log-probs in parallel
        // First we collect all tasks
        eval_pairs.clear();
        for (size_t i = i0; i < i1; ++i) {
            auto & hs_cur = hs_data[i];
            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
            for (int s = 0; s < 4; ++s) {
                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
                    eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
                }
            }
        }
        // Then we do the actual calculation
        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);

        size_t ir = 0;

        // compute the logprobs for each ending of the decoded tasks
        for (size_t i = i0; i < i1; ++i) {
            auto & hs_cur = hs_data[i];

            // get the logits of the last token of the common prefix
            std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float));

            const auto first_probs = softmax(tok_logits);

            for (int s = 0; s < 4; ++s) {
                hs_cur.ending_logprob_count[s] = 1;
                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
                    hs_cur.ending_logprob[s] += eval_results[ir++];
                    hs_cur.ending_logprob_count[s]++;
                }
                hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
            }

            // Find the ending with maximum logprob
            size_t ending_logprob_max_idx = 0;
            double ending_logprob_max_val = hs_cur.ending_logprob[0];
            for (size_t s = 1; s < 4; s++) {
                if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
                    ending_logprob_max_idx = s;
                    ending_logprob_max_val = hs_cur.ending_logprob[s];
                }
            }

            //LOG("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);

            // If the gold ending got the maximum logprob, add one accuracy point
            if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
                acc += 1.0;
            }

            double freq = acc / double(i + 1);

            const double za = 1.95996398454;

            // // Wald normal approx
            // double conf = za*sqrt(freq*(1-freq)/double(i + 1));
            // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);

            // Wilson score interval, more accurate
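            // bounds = (p + za^2/(2n) -/+ za*sqrt(p(1-p)/n + za^2/(4n^2))) / (1 + za^2/n),
            // where p is the observed accuracy and n = i+1; the code below folds za^2/n into z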
            double z = za * za / double(i + 1);
            double cnf = z * sqrt(double(i + 1) * (4.0 * freq * (1 - freq) + z)) / (za + za);
            double a = (freq + z * 0.5 - cnf) / (1.0 + z);
            double b = (freq + z * 0.5 + cnf) / (1.0 + z);

            // Print the accumulated accuracy mean x 100 and confidence interval
            LOG("%zu\t%3.8lf%%\t[%3.4lf%%, %3.4lf%%]\n", i + 1, freq * 100.0, a * 100.0, b * 100.0);
        }

        i0 = i1 - 1;
    }

    llama_batch_free(batch);

    LOG("\n");
}

struct winogrande_entry {
    std::string first;
    std::string second;
    std::array<std::string, 2> choices;
    int answer;

    size_t i_logits;
    size_t common_prefix;
    size_t required_tokens;
    size_t n_base1; // number of tokens for context + choice 1
    size_t n_base2; // number of tokens for context + choice 2
    std::vector<llama_token> seq_tokens[2];
};

static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
    std::vector<winogrande_entry> result;
    std::istringstream in(prompt);
    std::string line;
    std::array<int, 4> comma_pos;
    while (true) {
        std::getline(in, line);
        if (in.fail() || in.eof()) break;
        int ipos = 0;
        bool quote_open = false;
        for (int i = 0; i < int(line.size()); ++i) {
            if (!quote_open) {
                if (line[i] == ',') {
                    comma_pos[ipos++] = i;
                    if (ipos == 4) break;
                }
                else if (line[i] == '"') {
                    quote_open = true;
                }
            }
            else {
                if (line[i] == '"') {
                    quote_open = false;
                }
            }
        }
        if (ipos != 4) {
            LOG_ERR("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
            continue;
        }
        auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
                                                    : line.substr(comma_pos[0]+1, comma_pos[1] - comma_pos[0] - 1);
        auto choice1 = line.substr(comma_pos[1]+1, comma_pos[2] - comma_pos[1] - 1);
        auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
        auto answer  = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
        auto index   = line.substr(0, comma_pos[0]);
        int where = 0;
        for ( ; where < int(sentence.size()); ++where) {
            if (sentence[where] == '_') break;
        }
        if (where == int(sentence.size())) {
            LOG_ERR("%s: no _ in <%s>\n", __func__, sentence.c_str());
            continue;
        }
        std::istringstream stream(answer.c_str());
        int i_answer; stream >> i_answer;
        if (stream.fail() || i_answer < 1 || i_answer > 2) {
            LOG_ERR("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
            continue;
        }
        result.emplace_back();
        auto & wg = result.back();
        wg.first = sentence.substr(0, where);
        wg.second = sentence.substr(where + 1, sentence.size() - where - 1);
        wg.choices[0] = std::move(choice1);
        wg.choices[1] = std::move(choice2);
        wg.answer = i_answer;
    }
    return result;
}

/*
 * Evaluates the Winogrande score.
 * Uses a CSV containing the task index, sentence, choice 1, choice 2, answer (1 or 2)
 * You can get one such dataset from e.g. https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp
 * As an example, the 1st row in the above dataset is
 *
 *    0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
 *
 */
static void winogrande_score(llama_context * ctx, const common_params & params) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    constexpr int k_min_trailing_ctx = 3;
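    // when both continuations keep more than k_min_trailing_ctx tokens after the common
    // prefix, the choice word itself is excluded from scoring (see skip_choice below)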

    auto data = load_winogrande_from_csv(params.prompt);
    if (data.empty()) {
        LOG_ERR("%s: no tasks\n", __func__);
        return;
    }

    LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, data.size());

    if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
        LOG_INF("%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
        std::mt19937 rng(1);
        std::vector<int> aux(data.size());
        for (int i = 0; i < int(data.size()); ++i) {
            aux[i] = i;
        }
        float scale = 1/(1.f + (float)rng.max());
        std::vector<winogrande_entry> selected;
        selected.resize(params.winogrande_tasks);
        for (int i = 0; i < int(params.winogrande_tasks); ++i) {
            int j = int(scale*rng()*aux.size());
            selected[i] = std::move(data[aux[j]]);
            aux[j] = aux.back();
            aux.pop_back();
        }
        data = std::move(selected);
    }

    LOG_INF("%s : tokenizing selected tasks\n", __func__);

    for (auto & task : data) {
        task.seq_tokens[0] = common_tokenize(ctx, task.first + task.choices[0] + task.second, true);
        task.seq_tokens[1] = common_tokenize(ctx, task.first + task.choices[1] + task.second, true);

        task.common_prefix = 0;
        for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
            if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
                break;
            }
            task.common_prefix++;
        }

        // TODO: the last token of each of the sequences don't need to be evaluated
        task.required_tokens = task.common_prefix +
            task.seq_tokens[0].size() - task.common_prefix +
            task.seq_tokens[1].size() - task.common_prefix;

        task.n_base1 = common_tokenize(ctx, task.first + task.choices[0], true).size();
        task.n_base2 = common_tokenize(ctx, task.first + task.choices[1], true).size();
    }

    LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);

    const int n_ctx   = llama_n_ctx(ctx);
    const int n_batch = params.n_batch;

    const int n_vocab = llama_vocab_n_tokens(vocab);

    const int max_tasks_per_batch = 128;
    const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

    llama_batch batch = llama_batch_init(n_ctx, 0, 2);

    std::vector<float> tok_logits(n_vocab);
    // TODO: this could be made smaller; it's currently the worst-case size
    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

    std::vector<std::pair<size_t, llama_token>> eval_pairs;
    std::vector<float> eval_results;
    std::vector<std::thread> workers(std::thread::hardware_concurrency());

    int n_correct = 0;
    int n_done    = 0;

    for (size_t i0 = 0; i0 < data.size(); i0++) {
        int n_cur = 0;

        size_t i1 = i0;
        size_t i_logits = 0;

        common_batch_clear(batch);

        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
            int n_logits = 0;
            const int s0 = 2*(i1 - i0);
            if (s0 + 2 > max_seq) {
                break;
            }

            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
                common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
            }
            batch.logits[batch.n_tokens - 1] = true;
            n_logits += 1;

            for (int s = 0; s < 2; ++s) {
                // TODO: end before the last token, no need to predict past the end of the sequences
                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
                    common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                    n_logits += 1;
                }
            }

            data[i1].i_logits = i_logits;
            i_logits += n_logits;

            n_cur += data[i1].required_tokens;
            if (++i1 == data.size()) {
                break;
            }
        }

        if (i0 == i1) {
            LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, data[i0].required_tokens);
            return;
        }

        llama_memory_clear(llama_get_memory(ctx), true);

        // decode all tasks [i0, i1)
        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return;
        }

        eval_pairs.clear();
        for (size_t i = i0; i < i1; ++i) {
            auto & task = data[i];

            const bool skip_choice =
                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;

            const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
            size_t li = n_base1 - task.common_prefix;
            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
            }
            const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
            // FIXME: this uses the wrong first logits when not skipping the choice word
            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
            }
        }
        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);

        size_t ir = 0;
        for (size_t i = i0; i < i1; ++i) {
            auto & task = data[i];

            const bool skip_choice =
                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;

            float score_1st = 0;
            const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
                score_1st += eval_results[ir++];
            }
            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);

            float score_2nd = 0;
            const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
                score_2nd += eval_results[ir++];
            }
            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);

            int result = score_1st > score_2nd ? 1 : 2;

            if (result == task.answer) {
                ++n_correct;
            }
            ++n_done;

            // print the accumulated accuracy mean x 100
            LOG("%zu\t%.4lf\t%10.6f  %10.6f  %d  %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
        }

        i0 = i1 - 1;
    }

    LOG("\n");

    if (n_done < 100) return;

    const float p = 1.f*n_correct/n_done;
    const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));

    LOG_INF("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
}

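// helpers for the binary multiple-choice prompt format: a string is serialized as a uint32
// length followed by its raw bytes; an answer set as a uint32 count, that many strings, and
// then `count` int labels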
| 1300 | static bool deserialize_string(std::istream & in, std::string & str) { |
| 1301 | uint32_t size; |
| 1302 | if (!in.read(s: (char *)&size, n: sizeof(size)).fail()) { |
| 1303 | str.resize(n: size); |
| 1304 | if (!in.read(s: (char *)&str[0], n: size).fail()) return true; |
| 1305 | } |
| 1306 | return false; |
| 1307 | } |
| 1308 | |
| 1309 | struct multiple_choice_answers { |
| 1310 | std::vector<std::string> answers; |
| 1311 | std::vector<int> labels; |
| 1312 | bool deserialize(std::istream& in) { |
| 1313 | uint32_t n; |
| 1314 | in.read(s: (char *)&n, n: sizeof(n)); |
| 1315 | if (in.fail() || n > 100) return false; // 100 as max. number of answers should be good enough for any practical purpose |
| 1316 | answers.resize(new_size: n); |
| 1317 | labels.resize(new_size: n); |
| 1318 | for (auto& a : answers) { |
| 1319 | if (!deserialize_string(in, str&: a)) return false; |
| 1320 | } |
| 1321 | in.read(s: (char *)labels.data(), n: n*sizeof(int)); |
| 1322 | return !in.fail(); |
| 1323 | } |
| 1324 | }; |
| 1325 | |
| 1326 | struct multiple_choice_task { |
| 1327 | std::string question; // the question (or context that needs to be continued) |
| 1328 | multiple_choice_answers mc1; // possible answers (continuations) with a single correct answer |
| 1329 | multiple_choice_answers mc2; // possible answers (continuations) with multiple correct answers - not handled yet |
| 1330 | bool deserialize(std::istream& in) { |
| 1331 | if (!deserialize_string(in, str&: question)) return false; |
| 1332 | return mc1.deserialize(in) && mc2.deserialize(in); |
| 1333 | } |
| 1334 | |
| 1335 | // For evaluation |
| 1336 | size_t i_logits; // starting index of logits in the llama_batch |
| 1337 | size_t common_prefix; // max number of initial tokens that are the same in all sentences |
| 1338 | size_t required_tokens; // needed number of tokens to evaluate all answers |
| 1339 | std::vector<std::vector<llama_token>> seq_tokens; |
| 1340 | std::vector<float> log_probs; |
| 1341 | }; |
| 1342 | |
| 1343 | static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) { |
| 1344 | if (task.question.empty() || task.mc1.answers.empty()) { |
| 1345 | if (log_error) { |
| 1346 | LOG_ERR("%s: found bad task with empty question and/or answers\n", __func__);
| 1347 | }
| 1348 | return false;
| 1349 | }
| 1350 | task.seq_tokens.reserve(task.mc1.answers.size());
| 1351 | for (auto& answer : task.mc1.answers) {
| 1352 | if (answer.empty()) {
| 1353 | if (log_error) {
| 1354 | LOG_ERR("%s: found empty answer\n", __func__);
| 1355 | }
| 1356 | return false;
| 1357 | }
| 1358 | task.seq_tokens.emplace_back(::common_tokenize(ctx, task.question + " " + answer, true));
| 1359 | }
| 1360 | auto min_len = task.seq_tokens.front().size();
| 1361 | for (auto& seq : task.seq_tokens) {
| 1362 | min_len = std::min(min_len, seq.size());
| 1363 | } |
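|      | // find the longest token prefix shared by all answer sequences; it will be
|      | // decoded once and shared between the per-answer sequences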
| 1364 | task.common_prefix = 0; |
| 1365 | for (size_t k = 0; k < min_len; ++k) { |
| 1366 | auto token = task.seq_tokens[0][k]; |
| 1367 | bool all_same = true; |
| 1368 | for (size_t i = 1; i < task.seq_tokens.size(); ++i) { |
| 1369 | if (task.seq_tokens[i][k] != token) { |
| 1370 | all_same = false; |
| 1371 | break; |
| 1372 | } |
| 1373 | } |
| 1374 | if (!all_same) { |
| 1375 | break; |
| 1376 | } |
| 1377 | ++task.common_prefix; |
| 1378 | } |
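|      | // tokens to decode for this task: the shared prefix once, plus each answer's unique tail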
| 1379 | task.required_tokens = task.common_prefix; |
| 1380 | for (auto& seq : task.seq_tokens) { |
| 1381 | task.required_tokens += seq.size() - task.common_prefix; |
| 1382 | } |
| 1383 | return true; |
| 1384 | } |
| 1385 | |
| 1386 | // |
| 1387 | // Calculates the score for multiple choice tasks with a single correct answer per prompt.
| 1388 | // Commonly used LLM evaluation benchmarks of this type are
| 1389 | // * ARC |
| 1390 | // * HellaSwag |
| 1391 | // * MMLU |
| 1392 | // * TruthfulQA |
| 1393 | // |
| 1394 | // Validation datasets for these 4 tests can be found at |
| 1395 | // https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp |
| 1396 | // The data for these datasets was extracted from |
| 1397 | // git@hf.co:datasets/allenai/ai2_arc |
| 1398 | // https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl |
| 1399 | // git@hf.co:datasets/Stevross/mmlu |
| 1400 | // https://huggingface.co/datasets/truthful_qa |
| 1401 | // |
| 1402 | static void multiple_choice_score(llama_context * ctx, const common_params & params) { |
| 1403 | const llama_model * model = llama_get_model(ctx); |
| 1404 | const llama_vocab * vocab = llama_model_get_vocab(model); |
| 1405 | |
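|      | // binary prompt layout: a uint32 task count, then one uint32 stream offset per task
|      | // (used below for random access), followed by the serialized tasks themselves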
| 1406 | std::istringstream strstream(params.prompt); |
| 1407 | uint32_t n_task; |
| 1408 | strstream.read((char *)&n_task, sizeof(n_task));
| 1409 | if (strstream.fail() || n_task == 0) {
| 1410 | LOG_ERR("%s: no tasks\n", __func__);
| 1411 | return;
| 1412 | }
| 1413 | LOG_INF("%s: there are %u tasks in prompt\n", __func__, n_task);
| 1414 | std::vector<uint32_t> task_pos(n_task);
| 1415 | strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
| 1416 | if (strstream.fail()) {
| 1417 | LOG_ERR("%s: failed to read task positions from prompt\n", __func__);
| 1418 | return; |
| 1419 | } |
| 1420 | |
| 1421 | std::vector<multiple_choice_task> tasks; |
| 1422 | if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) { |
| 1423 | // Use all tasks |
| 1424 | tasks.resize(n_task);
| 1425 | LOG_INF("%s: reading tasks", __func__);
| 1426 | int n_dot = std::max((int) n_task/100, 1);
| 1427 | int i = 0;
| 1428 | for (auto& task : tasks) {
| 1429 | ++i;
| 1430 | if (!task.deserialize(strstream)) {
| 1431 | LOG_ERR("%s: failed to read task %d of %u\n", __func__, i, n_task);
| 1432 | return;
| 1433 | }
| 1434 | if (i%n_dot == 0) LOG(".");
| 1435 | }
| 1436 | LOG("done\n");
| 1437 | } |
| 1438 | else { |
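|      | // sample tasks without replacement: draw a random index from the remaining pool and
|      | // swap-remove it (a partial Fisher-Yates); the fixed seed keeps the selection reproducible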
| 1439 | LOG_INF("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
| 1440 | std::mt19937 rng(1);
| 1441 | std::vector<int> aux(n_task);
| 1442 | for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
| 1443 | float scale = 1.f/(1.f + (float)std::mt19937::max());
| 1444 | tasks.resize(params.multiple_choice_tasks);
| 1445 | for (auto& task : tasks) {
| 1446 | int j = (int)(scale * rng() * aux.size());
| 1447 | int idx = aux[j];
| 1448 | aux[j] = aux.back();
| 1449 | aux.pop_back();
| 1450 | strstream.seekg(task_pos[idx], std::ios::beg);
| 1451 | if (!task.deserialize(strstream)) {
| 1452 | LOG_ERR("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
| 1453 | return; |
| 1454 | } |
| 1455 | } |
| 1456 | n_task = params.multiple_choice_tasks; |
| 1457 | } |
| 1458 | |
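|      | // tokenizing and preparing a large number of tasks can take a while, so spread the work
|      | // across threads: each thread claims chunks of K_TOKEN_CHUNK tasks via the shared atomic counter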
| 1459 | LOG_INF("%s: preparing task data", __func__);
| 1460 | if (n_task > 500) {
| 1461 | LOG("...");
| 1462 | std::atomic<int> counter(0);
| 1463 | std::atomic<int> n_bad(0);
| 1464 | auto prepare = [&counter, &n_bad, &tasks, ctx] () {
| 1465 | int num_tasks = tasks.size();
| 1466 | int n_bad_local = 0;
| 1467 | while (true) {
| 1468 | int first = counter.fetch_add(K_TOKEN_CHUNK);
| 1469 | if (first >= num_tasks) {
| 1470 | if (n_bad_local > 0) n_bad += n_bad_local;
| 1471 | break;
| 1472 | }
| 1473 | int last = std::min(first + K_TOKEN_CHUNK, num_tasks);
| 1474 | for (int i = first; i < last; ++i) {
| 1475 | if (!multiple_choice_prepare_one_task(ctx, tasks[i], false)) ++n_bad_local;
| 1476 | }
| 1477 | }
| 1478 | };
| 1479 | size_t max_thread = std::thread::hardware_concurrency();
| 1480 | max_thread = std::min(max_thread, (tasks.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK);
| 1481 | std::vector<std::thread> workers(max_thread-1);
| 1482 | for (auto& w : workers) w = std::thread(prepare);
| 1483 | prepare();
| 1484 | for (auto& w : workers) w.join();
| 1485 | LOG("done\n");
| 1486 | int nbad = n_bad;
| 1487 | if (nbad > 0) {
| 1488 | LOG_ERR("%s: found %d malformed tasks\n", __func__, nbad);
| 1489 | return;
| 1490 | }
| 1491 | } else {
| 1492 | int n_dot = std::max((int) n_task/100, 1);
| 1493 | int i_task = 0;
| 1494 | for (auto& task : tasks) {
| 1495 | ++i_task;
| 1496 | if (!multiple_choice_prepare_one_task(ctx, task, true)) {
| 1497 | return;
| 1498 | }
| 1499 | if (i_task%n_dot == 0) {
| 1500 | LOG(".");
| 1501 | }
| 1502 | }
| 1503 | LOG("done\n");
| 1504 | } |
| 1505 | |
| 1506 | LOG_INF("%s : calculating the multiple choice score over %zu tasks.\n", __func__, tasks.size());
| 1507 | 
| 1508 | LOG("\ntask\tacc_norm\n");
| 1509 | |
| 1510 | const int n_ctx = llama_n_ctx(ctx); |
| 1511 | const int n_batch = params.n_batch; |
| 1512 | |
| 1513 | const int n_vocab = llama_vocab_n_tokens(vocab); |
| 1514 | |
| 1515 | const int max_tasks_per_batch = 32; |
| 1516 | const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
| 1517 | 
| 1518 | llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
| 1519 | |
| 1520 | std::vector<float> tok_logits(n_vocab); |
| 1521 | std::vector<float> batch_logits(size_t(n_ctx)*n_vocab); |
| 1522 | |
| 1523 | std::vector<std::pair<size_t, llama_token>> eval_pairs; |
| 1524 | std::vector<float> eval_results; |
| 1525 | std::vector<std::thread> workers(std::thread::hardware_concurrency()); |
| 1526 | std::vector<int> batch_indices;
| 1527 | |
| 1528 | int n_done = 0; |
| 1529 | int n_correct = 0; |
| 1530 | int n_tot_answers = 0; |
| 1531 | |
| 1532 | for (size_t i0 = 0; i0 < tasks.size(); i0++) { |
| 1533 | int n_cur = 0; |
| 1534 | |
| 1535 | size_t i1 = i0; |
| 1536 | size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch |
| 1537 | |
| 1538 | common_batch_clear(batch); |
| 1539 | |
| 1540 | // batch as many tasks as possible into the available context
| 1541 | // each task gets one unique sequence id per answer
| 1542 | // the common prefix is shared among the task's sequences to save tokens
| 1543 | // we extract logits only from the last common token and from all ending tokens of each sequence |
| 1544 | int s0 = 0; |
| 1545 | while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) { |
| 1546 | auto& cur_task = tasks[i1]; |
| 1547 | int n_logits = 0; |
| 1548 | |
| 1549 | int num_answers = cur_task.seq_tokens.size(); |
| 1550 | if (s0 + num_answers > max_seq) { |
| 1551 | if (s0 == 0) { |
| 1552 | LOG_ERR("%s : task %zu requires a higher -np|--parallel value (at least %d)\n", __func__, i0, num_answers);
| 1553 | return; |
| 1554 | } |
| 1555 | break; |
| 1556 | } |
| 1557 | |
| 1558 | if (int(batch_indices.size()) != num_answers) {
| 1559 | batch_indices.resize(num_answers);
| 1560 | }
| 1561 | 
| 1562 | for (int s = 0; s < num_answers; ++s) {
| 1563 | batch_indices[s] = s0 + s;
| 1564 | }
| 1565 | 
| 1566 | for (size_t i = 0; i < cur_task.common_prefix; ++i) {
| 1567 | //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
| 1568 | common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indices, false);
| 1569 | } |
| 1570 | batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix |
| 1571 | n_logits += 1; |
| 1572 | |
| 1573 | for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { |
| 1574 | const size_t seq_tokens_size = cur_task.seq_tokens[s].size(); |
| 1575 | // TODO: don't evaluate the last token of each sequence |
| 1576 | for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { |
| 1577 | const bool needs_logits = i < seq_tokens_size - 1; |
| 1578 | common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
| 1579 | n_logits += needs_logits; |
| 1580 | } |
| 1581 | } |
| 1582 | |
| 1583 | s0 += num_answers; |
| 1584 | |
| 1585 | cur_task.i_logits = i_logits; |
| 1586 | i_logits += n_logits; |
| 1587 | |
| 1588 | n_cur += cur_task.required_tokens; |
| 1589 | if (++i1 == tasks.size()) { |
| 1590 | break; |
| 1591 | } |
| 1592 | } |
| 1593 | |
| 1594 | if (i0 == i1) { |
| 1595 | LOG_ERR("%s : task %zu does not fit in the context window (requires %zu tokens)\n", __func__, i0, tasks[i0].required_tokens);
| 1596 | return; |
| 1597 | } |
| 1598 | |
| 1599 | llama_memory_clear(llama_get_memory(ctx), true);
| 1600 | |
| 1601 | // decode all tasks [i0, i1) |
| 1602 | if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { |
| 1603 | LOG_ERR("%s: llama_decode() failed\n", __func__);
| 1604 | return; |
| 1605 | } |
| 1606 | |
| 1607 | // Compute log-probs in parallel |
| 1608 | // First we collect all tasks |
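|      | // each pair is (index of a logit row in batch_logits, the token that actually follows);
|      | // the workers turn these into log p(token | context) values in eval_results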
| 1609 | eval_pairs.clear(); |
| 1610 | for (size_t i = i0; i < i1; ++i) { |
| 1611 | auto& cur_task = tasks[i]; |
| 1612 | size_t li = 1; // skip the last logit of the common prefix (computed separately below) |
| 1613 | for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { |
| 1614 | for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) { |
| 1615 | eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
| 1616 | } |
| 1617 | } |
| 1618 | } |
| 1619 | // Then we do the actual calculation |
| 1620 | compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
| 1621 | |
| 1622 | size_t ir = 0; |
| 1623 | |
| 1624 | // compute the logprobs for each ending of the decoded tasks |
| 1625 | for (size_t i = i0; i < i1; ++i) { |
| 1626 | auto & cur_task = tasks[i]; |
| 1627 | //LOG("==== Evaluating <%s> with correct answer ", cur_task.question.c_str()); |
| 1628 | //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) { |
| 1629 | // if (cur_task.mc1.labels[j] == 1) { |
| 1630 | // LOG("%d", j+1); |
| 1631 | // } |
| 1632 | //} |
| 1633 | //LOG("\n common_prefix: %zu\n", cur_task.common_prefix); |
| 1634 | |
| 1635 | // get the logits of the last token of the common prefix |
| 1636 | std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float));
| 1637 | 
| 1638 | const auto first_probs = softmax(tok_logits);
| 1639 | |
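|      | // score each answer by its mean log-probability per token; this length
|      | // normalization is what makes the reported accuracy an acc_norm-style metric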
| 1640 | cur_task.log_probs.resize(cur_task.seq_tokens.size());
| 1641 | for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
| 1642 | size_t count = 1;
| 1643 | float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
| 1644 | for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) { |
| 1645 | //LOG(" %zu %g\n", ir, eval_results[ir]); |
| 1646 | ++count; |
| 1647 | log_prob += eval_results[ir++]; |
| 1648 | } |
| 1649 | cur_task.log_probs[s] = log_prob / count; |
| 1650 | //LOG(" Final: %g\n", log_prob / count); |
| 1651 | //LOG(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count); |
| 1652 | } |
| 1653 | |
| 1654 | // Find the ending with maximum logprob |
| 1655 | size_t logprob_max_idx = 0; |
| 1656 | float logprob_max_val = cur_task.log_probs[0]; |
| 1657 | for (size_t s = 1; s < cur_task.log_probs.size(); s++) { |
| 1658 | if (cur_task.log_probs[s] > logprob_max_val) { |
| 1659 | logprob_max_val = cur_task.log_probs[s]; |
| 1660 | logprob_max_idx = s; |
| 1661 | } |
| 1662 | } |
| 1663 | |
| 1664 | n_tot_answers += cur_task.log_probs.size(); |
| 1665 | if (cur_task.mc1.labels[logprob_max_idx] == 1) { |
| 1666 | ++n_correct; |
| 1667 | } |
| 1668 | ++n_done; |
| 1669 | |
| 1670 | // Print the accumulated accuracy mean x 100 |
| 1671 | LOG("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
| 1672 | } |
| 1673 | |
| 1674 | i0 = i1 - 1; |
| 1675 | } |
| 1676 | |
| 1677 | llama_batch_free(batch); |
| 1678 | |
| 1679 | if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; |
| 1680 | |
| 1681 | float p = 1.f*n_correct/n_done; |
| 1682 | float sigma = sqrt(p*(1-p)/(n_done-1));
| 1683 | LOG("\n");
| 1684 | LOG_INF("Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
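|      | // baseline for uniform guessing: n_done/n_tot_answers is the inverse of the
|      | // average number of answers per task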
| 1685 | p = 1.f*n_done/n_tot_answers; |
| 1686 | sigma = sqrt(p*(1-p)/(n_done-1));
| 1687 | LOG_INF("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
| 1688 | |
| 1689 | LOG_INF("\n");
| 1690 | } |
| 1691 | |
| 1692 | static void kl_divergence(llama_context * ctx, const common_params & params) { |
| 1693 | const llama_model * model = llama_get_model(ctx); |
| 1694 | const llama_vocab * vocab = llama_model_get_vocab(model); |
| 1695 | |
| 1696 | if (params.logits_file.empty()) { |
| 1697 | LOG_ERR("%s: you must provide the name of a file containing the log probabilities of the base model\n", __func__);
| 1698 | return;
| 1699 | }
| 1700 | std::ifstream in(params.logits_file.c_str(), std::ios::binary);
| 1701 | if (!in) {
| 1702 | LOG_ERR("%s: failed to open %s\n", __func__, params.logits_file.c_str());
| 1703 | return; |
| 1704 | } |
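|      | // the logits file starts with the 8-byte magic "_logits_", followed by n_ctx, n_vocab and
|      | // n_chunk, then the evaluation tokens, and finally one block of quantized log-probs per chunk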
| 1705 | { |
| 1706 | char check[9]; check[8] = 0; |
| 1707 | in.read(check, 8);
| 1708 | if (in.fail() || strncmp("_logits_", check, 8) != 0) {
| 1709 | LOG_ERR("%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
| 1710 | return; |
| 1711 | } |
| 1712 | } |
| 1713 | |
| 1714 | uint32_t n_ctx;
| 1715 | in.read((char *)&n_ctx, sizeof(n_ctx));
| 1716 | if (n_ctx > llama_n_ctx(ctx)) {
| 1717 | LOG_ERR("%s: %s was computed with a context of %u, while the current context is %d. Increase it with -c and retry\n",
| 1718 | __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
| 1719 | }
| 1720 | 
| 1721 | int n_vocab;
| 1722 | int n_chunk;
| 1723 | in.read((char *)&n_vocab, sizeof(n_vocab));
| 1724 | in.read((char *)&n_chunk, sizeof(n_chunk));
| 1725 | if (in.fail()) {
| 1726 | LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
| 1727 | return;
| 1728 | }
| 1729 | if (n_vocab != llama_vocab_n_tokens(vocab)) {
| 1730 | LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab));
| 1731 | }
| 1732 | 
| 1733 | std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
| 1734 | if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
| 1735 | LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
| 1736 | return; |
| 1737 | } |
| 1738 | |
| 1739 | const int n_batch = params.n_batch; |
| 1740 | const int num_batches = (n_ctx + n_batch - 1)/n_batch; |
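|      | // per-token record size in uint16 units: 4 slots hold two float32 header values
|      | // (the quantization scale and the minimum log-prob), followed by n_vocab quantized
|      | // log-probs rounded up to an even count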
| 1741 | const int nv = 2*((n_vocab + 1)/2) + 4; |
| 1742 | const bool add_bos = llama_vocab_get_add_bos(vocab); |
| 1743 | GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); |
| 1744 | |
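|      | // only tokens in the second half of each window are scored, so each chunk
|      | // contributes n_ctx - 1 - n_ctx/2 values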
| 1745 | std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); |
| 1746 | std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); |
| 1747 | std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); |
| 1748 | std::vector<float> logits; |
| 1749 | if (num_batches > 1) { |
| 1750 | logits.reserve(size_t(n_ctx) * n_vocab);
| 1751 | } |
| 1752 | |
| 1753 | std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1); |
| 1754 | |
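|      | // sample mean of a running sum, together with the standard error of that mean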
| 1755 | auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) { |
| 1756 | if (count < 1) { |
| 1757 | return std::make_pair(0., 0.);
| 1758 | }
| 1759 | double f = sum/count;
| 1760 | double df = sum2/count - f*f;
| 1761 | df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
| 1762 | return std::make_pair(f, df);
| 1763 | }; |
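|      | // covariance of the two sample means, used below for error propagation on correlated quantities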
| 1764 | auto covariance = [] (double suma, double sumb, double sumab, size_t count) { |
| 1765 | if (count < 10) { |
| 1766 | return 0.0; |
| 1767 | } |
| 1768 | double var = sumab/count - (suma/count)*(sumb/count); |
| 1769 | var /= count - 1; |
| 1770 | return var; |
| 1771 | }; |
| 1772 | |
| 1773 | kl_divergence_result kld; |
| 1774 | auto kld_ptr = kld_values.data(); |
| 1775 | auto p_diff_ptr = p_diff_values.data(); |
| 1776 | |
| 1777 | for (int i = 0; i < n_chunk; ++i) { |
| 1778 | const int start = i * n_ctx; |
| 1779 | const int end = start + n_ctx; |
| 1780 | |
| 1781 | const auto t_start = std::chrono::high_resolution_clock::now(); |
| 1782 | |
| 1783 | if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
| 1784 | LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
| 1785 | return; |
| 1786 | } |
| 1787 | |
| 1788 | // clear the KV cache |
| 1789 | llama_memory_clear(llama_get_memory(ctx), true);
| 1790 | 
| 1791 | llama_batch batch = llama_batch_init(n_batch, 0, 1);
| 1792 | |
| 1793 | for (int j = 0; j < num_batches; ++j) { |
| 1794 | const int batch_start = start + j * n_batch; |
| 1795 | const int batch_size = std::min(end - batch_start, n_batch);
| 1796 | |
| 1797 | // save original token and restore it after eval |
| 1798 | const auto token_org = tokens[batch_start]; |
| 1799 | |
| 1800 | // add BOS token for the first batch of each chunk |
| 1801 | if (add_bos && j == 0) { |
| 1802 | tokens[batch_start] = llama_vocab_bos(vocab); |
| 1803 | } |
| 1804 | |
| 1805 | common_batch_clear(batch); |
| 1806 | for (int k = 0; k < batch_size; k++) {
| 1807 | common_batch_add(batch, tokens[batch_start + k], j*n_batch + k, {0}, true);
| 1808 | } |
| 1809 | |
| 1810 | if (llama_decode(ctx, batch)) { |
| 1811 | LOG_ERR("%s : failed to eval\n", __func__);
| 1812 | llama_batch_free(batch); |
| 1813 | return; |
| 1814 | } |
| 1815 | |
| 1816 | // restore the original token in case it was set to BOS |
| 1817 | tokens[batch_start] = token_org; |
| 1818 | |
| 1819 | if (num_batches > 1) { |
| 1820 | const auto * batch_logits = llama_get_logits(ctx); |
| 1821 | logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
| 1822 | } |
| 1823 | } |
| 1824 | |
| 1825 | llama_batch_free(batch); |
| 1826 | |
| 1827 | const auto t_end = std::chrono::high_resolution_clock::now(); |
| 1828 | |
| 1829 | if (i == 0) { |
| 1830 | const float t_total = std::chrono::duration<float>(t_end - t_start).count(); |
| 1831 | LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
| 1832 | int total_seconds = (int)(t_total * n_chunk);
| 1833 | if (total_seconds >= 60*60) {
| 1834 | LOG("%d hours ", total_seconds / (60*60));
| 1835 | total_seconds = total_seconds % (60*60);
| 1836 | }
| 1837 | LOG("%.2f minutes\n", total_seconds / 60.0);
| 1838 | }
| 1839 | LOG("\n");
| 1840 | LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
| 1841 | |
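|      | // start scoring at n_ctx/2 so every scored token has at least half a window of context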
| 1842 | const int first = n_ctx/2; |
| 1843 | const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); |
| 1844 | process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
| 1845 | workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
| 1846 | p_diff_ptr += n_ctx - 1 - first; |
| 1847 | kld_ptr += n_ctx - 1 - first; |
| 1848 | |
| 1849 | LOG("%4d", i+1);
| 1850 | 
| 1851 | auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
| 1852 | const double ppl_val = exp(log_ppl.first);
| 1853 | const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
| 1854 | LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
| 1855 | |
| 1856 | auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); |
| 1857 | const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); |
| 1858 | const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; |
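|      | // uncertainty of a difference of correlated means: var(a-b) = var(a) + var(b) - 2*cov(a,b)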
| 1859 | const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
| 1860 | LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
| 1861 | |
| 1862 | auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); |
| 1863 | LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
| 1864 | |
| 1865 | auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); |
| 1866 | const double p_diff_rms_val = sqrt(p_diff_mse.first);
| 1867 | const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; // error propagation: d(sqrt(x))/dx = 0.5/sqrt(x)
| 1868 | LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
| 1869 | |
| 1870 | double p_top_val = 1.*kld.n_same_top/kld.count; |
| 1871 | double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
| 1872 | LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
| 1873 | |
| 1874 | LOG("\n");
| 1875 | |
| 1876 | logits.clear(); |
| 1877 | } |
| 1878 | LOG("\n");
| 1879 | |
| 1880 | if (kld.count < 100) return; // we do not wish to do statistics on so few values |
| 1881 | |
| 1882 | std::sort(kld_values.begin(), kld_values.end());
| 1883 | std::sort(p_diff_values.begin(), p_diff_values.end());
| 1884 | 
| 1885 | LOG("====== Perplexity statistics ======\n");
| 1886 | |
| 1887 | auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
| 1888 | const double ppl_val = exp(log_ppl.first);
| 1889 | const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
| 1890 | LOG("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);
| 1891 | 
| 1892 | auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
| 1893 | const double ppl_base_val = exp(log_ppl_base.first);
| 1894 | const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 )
| 1895 | LOG("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);
| 1896 | 
| 1897 | const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
| 1898 | // LOG("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
| 1899 | const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
| 1900 | LOG("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);
| 1901 | 
| 1902 | const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
| 1903 | const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
| 1904 | LOG("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);
| 1905 | 
| 1906 | const double ppl_ratio_val = exp(log_ppl_ratio_val);
| 1907 | const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
| 1908 | LOG("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);
| 1909 | 
| 1910 | const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
| 1911 | const double ppl_diff_val = ppl_val - ppl_base_val;
| 1912 | const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
| 1913 | LOG("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);
| 1914 | |
| 1915 | LOG("\n");
| 1916 | |
| 1917 | LOG("====== KL divergence statistics ======\n");
| 1918 | auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); |
| 1919 | LOG("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
| 1920 | auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1]) |
| 1921 | : kld_values[kld_values.size()/2]; |
| 1922 | |
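|      | // percentile of a sorted vector, with linear interpolation between neighboring values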
| 1923 | auto percentile = [] (const std::vector<float> & values, float fraction) {
| 1924 | if (fraction <= 0) return values.front(); |
| 1925 | if (fraction >= 1) return values.back(); |
| 1926 | float p = fraction*(values.size() - 1); |
| 1927 | size_t ip = size_t(p); p -= ip; |
| 1928 | return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
| 1929 | }; |
| 1930 | |
| 1931 | LOG("Maximum KLD: %10.6f\n", kld_values.back());
| 1932 | LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
| 1933 | LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
| 1934 | LOG("95.0%% KLD: %10.6f\n", percentile(kld_values, 0.950f));
| 1935 | LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f));
| 1936 | LOG("Median KLD: %10.6f\n", kld_median);
| 1937 | LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
| 1938 | LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
| 1939 | LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
| 1940 | LOG(" 0.1%% KLD: %10.6f\n", percentile(kld_values, 0.001f));
| 1941 | LOG("Minimum KLD: %10.6f\n", kld_values.front());
| 1942 | |
| 1943 | LOG("\n");
| 1944 | |
| 1945 | LOG("====== Token probability statistics ======\n");
| 1946 | 
| 1947 | auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
| 1948 | LOG("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second);
| 1949 | |
| 1950 | auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1]) |
| 1951 | : p_diff_values[p_diff_values.size()/2]; |
| 1952 | |
| 1953 | LOG("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back());
| 1954 | LOG("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
| 1955 | LOG("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
| 1956 | LOG("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
| 1957 | LOG("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
| 1958 | LOG("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
| 1959 | LOG("Median Δp: %6.3lf%%\n", 100.0*p_diff_median);
| 1960 | LOG("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
| 1961 | LOG("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
| 1962 | LOG(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
| 1963 | LOG(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
| 1964 | LOG(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
| 1965 | LOG("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front());
| 1966 | |
| 1967 | auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); |
| 1968 | // LOG("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second); |
| 1969 | |
| 1970 | const double p_diff_rms_val = sqrt(p_diff_mse.first);
| 1971 | const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
| 1972 | LOG("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
| 1973 | 
| 1974 | const double same_top_p = 1.0*kld.n_same_top/kld.count;
| 1975 | LOG("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
| 1976 | } |
| 1977 | |
| 1978 | int main(int argc, char ** argv) { |
| 1979 | common_params params; |
| 1980 | |
| 1981 | params.n_ctx = 512; |
| 1982 | params.escape = false; |
| 1983 | |
| 1984 | if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
| 1985 | return 1; |
| 1986 | } |
| 1987 | |
| 1988 | common_init(); |
| 1989 | |
| 1990 | const int32_t n_ctx = params.n_ctx; |
| 1991 | |
| 1992 | if (n_ctx <= 0) { |
| 1993 | LOG_ERR("%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
| 1994 | return 1; |
| 1995 | } |
| 1996 | |
| 1997 | const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence; |
| 1998 | |
| 1999 | if (ppl) { |
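|      | // plain perplexity: when n_batch > n_ctx, evaluate several n_ctx-sized sequences in parallel within one batch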
| 2000 | const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
| 2001 | const int32_t n_kv = n_seq * n_ctx;
| 2002 | 
| 2003 | params.n_parallel = n_seq;
| 2004 | params.n_ctx = n_kv;
| 2005 | 
| 2006 | params.n_batch = std::min(params.n_batch, n_kv);
| 2007 | } else {
| 2008 | params.n_batch = std::min(params.n_batch, params.n_ctx);
| 2009 | if (params.kl_divergence) {
| 2010 | params.n_parallel = 1;
| 2011 | } else {
| 2012 | // ensure there's at least enough seq_ids for HellaSwag
| 2013 | params.n_parallel = std::max(4, params.n_parallel);
| 2014 | } |
| 2015 | } |
| 2016 | |
| 2017 | if (params.ppl_stride > 0) { |
| 2018 | LOG_INF("Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
| 2019 | params.n_ctx, params.n_ctx + params.ppl_stride/2);
| 2020 | params.n_ctx += params.ppl_stride/2; |
| 2021 | } |
| 2022 | |
| 2023 | llama_backend_init(); |
| 2024 | llama_numa_init(params.numa);
| 2025 | |
| 2026 | // load the model and apply lora adapter, if any |
| 2027 | common_init_result llama_init = common_init_from_params(params); |
| 2028 | |
| 2029 | llama_model * model = llama_init.model.get(); |
| 2030 | llama_context * ctx = llama_init.context.get(); |
| 2031 | |
| 2032 | if (model == NULL) { |
| 2033 | LOG_ERR("%s: unable to load model\n", __func__);
| 2034 | return 1; |
| 2035 | } |
| 2036 | |
| 2037 | const int n_ctx_train = llama_model_n_ctx_train(model); |
| 2038 | |
| 2039 | if (params.n_ctx > n_ctx_train) { |
| 2040 | LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
| 2041 | __func__, n_ctx_train, params.n_ctx);
| 2042 | } |
| 2043 | |
| 2044 | // print system information |
| 2045 | { |
| 2046 | LOG_INF("\n");
| 2047 | LOG_INF("%s\n", common_params_get_system_info(params).c_str());
| 2048 | } |
| 2049 | |
| 2050 | struct results_perplexity results; |
| 2051 | if (params.hellaswag) { |
| 2052 | hellaswag_score(ctx, params); |
| 2053 | } else if (params.winogrande) { |
| 2054 | winogrande_score(ctx, params); |
| 2055 | } else if (params.multiple_choice) { |
| 2056 | multiple_choice_score(ctx, params); |
| 2057 | } else if (params.kl_divergence) { |
| 2058 | kl_divergence(ctx, params); |
| 2059 | } else { |
| 2060 | results = perplexity(ctx, params, n_ctx); |
| 2061 | } |
| 2062 | |
| 2063 | LOG("\n");
| 2064 | llama_perf_context_print(ctx); |
| 2065 | llama_memory_breakdown_print(ctx); |
| 2066 | |
| 2067 | llama_backend_free(); |
| 2068 | |
| 2069 | return 0; |
| 2070 | } |
| 2071 | |