#include "llama-batch.h"

#include "llama-impl.h"
#include "llama-vocab.h"
#include "llama-memory.h"

#include <cassert>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <sstream>

llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;

    seq_pos.resize(LLAMA_MAX_SEQ);
    seq_cpl.resize(LLAMA_MAX_SEQ);
    for (auto & cur : seq_cpl) {
        cur.resize(LLAMA_MAX_SEQ);
    }

    seq_idx.resize(LLAMA_MAX_SEQ, -1);
}

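// high-level flow of init() (summary comment added for orientation; the
// implementation below is authoritative):
//   1. validate the input batch (token ids, seq_ids)
//   2. auto-generate missing fields (n_seq_id, seq_id, pos, logits)
//   3. compute stats (outputs, coupled sequences, per-token sequence sets)
//   4. run consistency checks on the sequence positions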
bool llama_batch_allocr::init(
        const llama_batch & batch_inp,
        const llama_vocab & vocab,
        const llama_memory_i * memory,
        uint32_t n_embd,
        uint32_t n_seq_max,
        bool output_all) {
    clear();

    batch = batch_inp;

    this->vocab = &vocab;

    GGML_ASSERT(batch.n_tokens > 0);

    //
    // validate input batch
    //

    if (n_seq_max > LLAMA_MAX_SEQ) {
        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
        return false;
    }

    if (batch.token) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return false;
            }
        }
    }

    if (batch.seq_id) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                if (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max) {
                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                    return false;
                }
            }
        }
    }

    //
    // auto-generate missing fields
    //

    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }

    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = nullptr;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }

    if (!batch.pos) {
        pos.resize(batch.n_tokens);

        // initialize the starting position for each sequence based on the positions in the memory
        llama_pos p0[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (!memory) {
                // if no memory -> start from 0
                p0[s] = 0;
            } else {
                p0[s] = memory->seq_pos_max(s) + 1;
            }
        }

        for (int32_t i = 0; i < batch.n_tokens; i++) {
            const llama_seq_id seq_id = batch.seq_id[i][0];

            pos[i] = p0[seq_id];

            // update the starting position for all sequences that are assigned to this token
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                p0[seq_id] = pos[i] + 1;
            }
        }

        batch.pos = pos.data();
    }
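
    // example (sketch, illustrative only): if the memory already holds positions
    // 0..4 for sequence 0, a 3-token batch for sequence 0 that arrives with
    // batch.pos == NULL is assigned positions 5, 6, 7 by the loop above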

    if (!batch.logits) {
        if (output_all) {
            // return the output for all tokens
            output.resize(batch.n_tokens, true);
        } else {
            // return the output only for the last token
            output.resize(batch.n_tokens, false);
            output[output.size() - 1] = true;
        }

        batch.logits = output.data();
    } else if (output_all) {
        bool warn = false;

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.logits[i] == 0) {
                warn = true;
            }
        }

        if (warn) {
            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);

            output.resize(batch.n_tokens, true);
            batch.logits = output.data();
        }
    }

    //
    // compute stats
    //

    this->n_embd = n_embd;
    this->n_seq_max = n_seq_max;

    // count the outputs in this batch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        n_outputs += batch.logits[i] != 0;
    }

    has_cpl = false;

    // determine coupled sequences
    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        const llama_seq_id s0 = batch.seq_id[i][0];

        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
            const llama_seq_id s1 = batch.seq_id[i][s];

            seq_pos[s1].insert(batch.pos[i]);

            if (s > 0) {
                // mark that sequence s1 is coupled to s0
                seq_cpl[s1][s0] = true;

                // note: tracking the other way around is not necessary for now
                //seq_cpl[s0][s1] = true;

                has_cpl = true;
            }
        }
    }

    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
    {
        seq_set_t seq_set_unq;

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            seq_set_t cur;
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cur        .set(seq_id);
                seq_set_unq.set(seq_id);
            }

            seq_set.push_back(cur);
            seq_set_map[cur].push_back(i);
        }

        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (seq_set_unq.test(s)) {
                seq_idx[s] = seq_id_unq.size();
                seq_id_unq.push_back(s);
            }
        }
    }
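
    // example (sketch, illustrative only): if the batch touches sequences {0, 5, 7},
    // then seq_id_unq == [0, 5, 7] and seq_idx maps 0 -> 0, 5 -> 1, 7 -> 2, with all
    // other entries left at -1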

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);

        llama_ubatch ubatch {
            /*.b_equal_seqs =*/ false,
            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
            /*.n_seq_tokens =*/ (uint32_t) 1,
            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
            /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
            /*.n_pos        =*/ n_pos_per_embd,
            /*.token        =*/ batch.token,
            /*.embd         =*/ batch.embd,
            /*.pos          =*/ batch.pos,
            /*.n_seq_id     =*/ batch.n_seq_id,
            /*.seq_id       =*/ batch.seq_id,
            /*.seq_id_unq   =*/ this->seq_id_unq.data(),
            /*.seq_idx      =*/ this->seq_idx.data(),
            /*.output       =*/ batch.logits,
            /*.data         =*/ {},
        };

        ubatch_print(ubatch, debug);

        LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
            if (seq_pos[s0].empty()) {
                continue;
            }

            std::stringstream ss;
            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
                if (seq_cpl[s0][s1]) {
                    ss << s1 << " ";
                }
            }

            LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
        }
        LLAMA_LOG_DEBUG("%s: ]\n", __func__);
    }

    //
    // consistency checks
    //

    if (n_pos_per_embd > 1) {
        // M-RoPE case: allow the position to "jump" forward only (non-continuous positions are allowed)
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (seq_pos[s].empty()) {
                continue;
            }

            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

            if (batch.token) {
                if (p0 >= 0 && p0 >= seq_pos_min(s)) {
                    LLAMA_LOG_ERROR(
                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                            " for M-RoPE, it is required that the position satisfies: X < Y\n",
                            __func__, s, s, p0, s, seq_pos_min(s));

                    return false;
                }
            } else {
                // embedding inputs can have overlapping positions
                if (p0 >= 0 && p0 > seq_pos_min(s)) {
                    LLAMA_LOG_ERROR(
                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                            " for M-RoPE, it is required that the position satisfies: X <= Y\n",
                            __func__, s, s, p0, s, seq_pos_min(s));

                    return false;
                }
            }
        }
    } else {
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (seq_pos[s].empty()) {
                continue;
            }

            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

            if (p0 >= 0) {
                if (seq_pos_min(s) != p0 + 1) {
                    LLAMA_LOG_ERROR(
                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
                            __func__, s, s, p0, s, seq_pos_min(s));

                    return false;
                }
            }

            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
                return false;
            }
        }
    }

    if (memory) {
        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                if (seq_cpl[s0][s1]) {
                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but they have diverged\n", __func__, s0, s1);
                        return false;
                    }
                }
            }
        }
    }

    // disallow partial sequence sub-sets:
    //
    //         invalid:     x
    //               i: 0 1 2 ...
    //   ---------------------------------------
    //    seq_id[i][0]: 0 0 1
    //    seq_id[i][1]: 1 1 2
    //    seq_id[i][2]: 2
    //
    // disallow decreasing sequence positions:
    //
    //         invalid:             x
    //               i: 0 1 2 3 4 5 6 ...
    //   ---------------------------------------
    //          pos[i]: 4 5 0 1 6 2 3
    //    seq_id[i][0]: 0 0 1 1 0 1 0
    //
    {
        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_set[s].set();
        }

        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_pos[s] = -1;
        }

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            const llama_pos pos = batch.pos[i];

            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cur_seq_set[seq_id] &= seq_set[i];

                if (cur_seq_set[seq_id].none()) {
                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
                    return false;
                }

                if (pos < cur_seq_pos[seq_id]) {
                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
                    return false;
                }
            }
        }
    }

    split_reset();

    return true;
}

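// usage sketch (hypothetical caller, not part of this file): the allocator is
// typically driven as init() followed by repeated split_*() calls until an
// empty ubatch is returned, e.g.:
//
//   llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);
//
//   if (balloc.init(batch, vocab, memory, n_embd, n_seq_max, /*output_all =*/ false)) {
//       while (true) {
//           llama_ubatch ub = balloc.split_simple(n_ubatch);
//           if (ub.n_tokens == 0) {
//               break;
//           }
//           // ... process ub ...
//       }
//   }
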
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
    const uint32_t n_tokens = n_seq_tokens*n_seqs;

    clear();
    split_reset();

    auto udata = std::make_shared<llama_ubatch::data_t>();

    udata->token     .resize(n_tokens);
    udata->embd      .clear();
    udata->pos       .resize(n_tokens);
    udata->n_seq_id  .resize(n_tokens);
    udata->seq_id    .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    udata->output    .resize(n_tokens);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        udata->seq_idx[s] = s;
        udata->seq_id_unq.push_back(s);
    }

    llama_ubatch res {
        /*.b_equal_seqs =*/ true,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_seq_tokens,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ n_seqs,
        /*.n_pos        =*/ n_pos_per_embd,

        /*.token        =*/ udata->token.data(),
        /*.embd         =*/ nullptr,
        /*.pos          =*/ udata->pos.data(),
        /*.n_seq_id     =*/ udata->n_seq_id.data(),
        /*.seq_id       =*/ udata->seq_id.data(),
        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
        /*.seq_idx      =*/ udata->seq_idx.data(),
        /*.output       =*/ udata->output.data(),
        /*.data         =*/ std::move(udata),
    };

    return res;
}

const llama_batch & llama_batch_allocr::get_batch() const {
    return batch;
}

uint32_t llama_batch_allocr::get_n_tokens() const {
    return batch.n_tokens;
}

uint32_t llama_batch_allocr::get_n_outputs() const {
    return n_outputs;
}

uint32_t llama_batch_allocr::get_n_used() const {
    return n_used;
}

std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
    return out_ids;
}

llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
}

llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
}

void llama_batch_allocr::split_reset() {
    out_ids.clear();

    n_used = 0;

    used.clear();
    used.resize(get_n_tokens(), false);
}

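// note (summary comment added for orientation): three splitting strategies follow -
//   - split_simple: take up to n_ubatch consecutive unused tokens, no layout guarantees
//   - split_equal:  take an equal number of tokens from each non-overlapping sequence set
//   - split_seq:    take tokens from a single sequence (set) at a time
// each call marks the returned tokens as used; an empty ubatch signals completion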
llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }

    std::vector<int32_t> idxs;

    while (true) {
        idxs.push_back(cur_idx);

        used[cur_idx] = true;
        ++n_used;

        ++cur_idx;

        if (cur_idx >= used.size()) {
            break;
        }

        if (idxs.size() >= n_ubatch) {
            break;
        }
    }

    return ubatch_add(idxs, idxs.size(), false);
}

llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
    if (sequential && has_cpl) {
        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);

        return {};
    }

    std::vector<seq_set_t> cur_seq_set;

    llama_seq_id last_seq_id = -1;

    // determine the non-overlapping sequence sets participating in this ubatch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (used[i]) {
            continue;
        }

        bool add = true;

        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
            // no overlap with existing sequence sets:
            if (!(cur_seq_set[s] & seq_set[i]).none()) {
                add = false;
                break;
            }
        }

        // accept only increasing sequence ids
        if (sequential) {
            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
        }

        if (add) {
            cur_seq_set.push_back(seq_set[i]);

            last_seq_id = batch.seq_id[i][0];

            if (cur_seq_set.size() > n_ubatch) {
                break;
            }
        }
    }

    const uint32_t n_seqs = cur_seq_set.size();

    // we are done
    if (n_seqs == 0) {
        return {};
    }

    // the current batch index of each sequence set
    std::vector<int32_t> cur_idx(n_seqs, 0);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
            ++cur_idx[s];
        }
    }

    // the list of batch indices for each sequence set
    // at the end we will concat these to get the final ubatch
    std::vector<idx_vec_t> idxs_per_seq(n_seqs);

    while (true) {
        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
        // if we haven't reached n_ubatch
        bool can_expand = true;

        for (uint32_t s = 0; s < n_seqs; ++s) {
            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                can_expand = false;
                break;
            }
        }

        if (!can_expand) {
            break;
        }

        for (uint32_t s = 0; s < n_seqs; ++s) {
            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];

            idxs_per_seq[s].push_back(idx);

            used[idx] = true;
            ++n_used;

            ++cur_idx[s];
        }

        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
            break;
        }
    }

    // concat the per-sequence-set lists
    std::vector<int32_t> idxs;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
    }

    return ubatch_add(idxs, n_seqs, true);
}
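
// example (sketch, illustrative only): for a batch with two disjoint sequence
// sets a = {0} and b = {1}, each with 3 unused tokens, split_equal(4, false)
// returns a ubatch laid out as [a0 a1 b0 b1], i.e. n_seqs = 2 and
// n_seq_tokens = 2, with the tokens of each sequence set kept contiguous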

llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }

    // this is the starting sequence set
    // we allow adding tokens only if their sequence set is a subset of the current sequence set
    auto cur_seq_set = seq_set[cur_idx];

    std::vector<int32_t> idxs;

    while (true) {
        idxs.push_back(cur_idx);

        used[cur_idx] = true;
        ++n_used;

        if (idxs.size() >= n_ubatch) {
            break;
        }

        do {
            ++cur_idx;
        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));

        if (cur_idx == get_n_tokens()) {
            break;
        }

        cur_seq_set = seq_set[cur_idx];
    }

    return ubatch_add(idxs, 1, true);
}
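
// example (sketch, illustrative only): for a batch [t0(seq 0), t1(seq 0), t2(seq 1)],
// split_seq(8) first returns a ubatch [t0 t1] (sequence set {0}); a second call
// then returns [t2] (sequence set {1}), each with n_seqs = 1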

void llama_batch_allocr::clear() {
    n_outputs = 0;

    batch = {};

    pos       .clear();
    n_seq_id  .clear();
    seq_id    .clear();
    seq_id_unq.clear();
    output    .clear();

    for (auto & cur : seq_pos) {
        cur.clear();
    }

    for (auto & cur : seq_cpl) {
        std::fill(cur.begin(), cur.end(), false);
    }

    seq_set.clear();

    seq_set_map.clear();

    std::fill(seq_idx.begin(), seq_idx.end(), -1);
}

llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
    const uint32_t n_tokens = idxs.size();

    assert(n_tokens%n_seqs == 0);

    auto udata = std::make_shared<llama_ubatch::data_t>();

    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_per_embd;

    udata->token     .resize(n_tokens);
    udata->embd      .resize(n_embd_all);
    udata->pos       .resize(n_pos_all);
    udata->n_seq_id  .resize(n_tokens);
    udata->seq_id    .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    udata->output    .resize(n_tokens);

    seq_set_t seq_set_unq;

    for (size_t i = 0; i < idxs.size(); ++i) {
        if (batch.token) {
            udata->token[i] = batch.token[idxs[i]];
        }

        if (batch.embd) {
            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
        }

        for (size_t j = 0; j < (size_t) n_pos_per_embd; ++j) {
            // if we are using M-RoPE:
            //   if the current batch is text, we need to broadcast the same position across all RoPE sections
            //   otherwise, the input batch contains image embeddings and we copy the positions as-is
            // if we are not using M-RoPE, there is only one position per token (this loop runs only once)
            const size_t src_off = batch.token ? 0 : j*batch.n_tokens;
            udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
        }

        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
        udata->seq_id[i]   = batch.seq_id[idxs[i]];
        udata->output[i]   = batch.logits[idxs[i]];

        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
            seq_set_unq.set(udata->seq_id[i][s]);
        }

        if (udata->output[i]) {
            out_ids.push_back(idxs[i]);
        }
    }

    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
            udata->seq_idx[s] = udata->seq_id_unq.size();
            udata->seq_id_unq.push_back(s);
        }
    }

    llama_ubatch res {
        /*.b_equal_seqs =*/ equal_seqs,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
        /*.n_pos        =*/ n_pos_per_embd,

        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
        /*.pos          =*/ udata->pos.data(),
        /*.n_seq_id     =*/ udata->n_seq_id.data(),
        /*.seq_id       =*/ udata->seq_id.data(),
        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
        /*.seq_idx      =*/ udata->seq_idx.data(),
        /*.output       =*/ udata->output.data(),
        /*.data         =*/ std::move(udata),
    };

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);

        ubatch_print(res, debug);
    }

    return res;
}

void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
        LLAMA_LOG_DEBUG("%s: n_tokens     = %d\n", __func__, ubatch.n_tokens);
        LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
        LLAMA_LOG_DEBUG("%s: n_seqs       = %d\n", __func__, ubatch.n_seqs);
        LLAMA_LOG_DEBUG("%s: n_seqs_unq   = %d\n", __func__, ubatch.n_seqs_unq);

        std::stringstream ss_seq_id_unq;
        std::stringstream ss_seq_idx;

        ss_seq_id_unq << "[ ";
        ss_seq_idx << "[";

        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
            ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
        }

        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (ubatch.seq_idx[s] >= 0) {
                ss_seq_idx << ubatch.seq_idx[s]%10;
            } else {
                ss_seq_idx << ".";
            }
        }

        ss_seq_id_unq << "]";
        ss_seq_idx << "]";

        LLAMA_LOG_DEBUG("%s: token      = %p\n", __func__, (void *) ubatch.token);
        LLAMA_LOG_DEBUG("%s: embd       = %p\n", __func__, (void *) ubatch.embd);
        LLAMA_LOG_DEBUG("%s: pos        = %p\n", __func__, (void *) ubatch.pos);
        LLAMA_LOG_DEBUG("%s: n_seq_id   = %p\n", __func__, (void *) ubatch.n_seq_id);
        LLAMA_LOG_DEBUG("%s: seq_id     = %p\n", __func__, (void *) ubatch.seq_id);
        LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
        LLAMA_LOG_DEBUG("%s: seq_idx    = %s\n", __func__, ss_seq_idx.str().c_str());
        LLAMA_LOG_DEBUG("%s: output     = %p\n", __func__, (void *) ubatch.output);
        LLAMA_LOG_DEBUG("%s: n_outputs  = %d\n", __func__, n_outputs);

        if (debug > 1) {
            int seq_id_max = 0;
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
                    seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
                }
            }
            ++seq_id_max;

            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                std::vector<int8_t> seq_id(seq_id_max);

                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
                    seq_id[ubatch.seq_id[i][s]] = 1;
                }

                std::stringstream ss;
                for (int s = 0; s < seq_id_max; ++s) {
                    if (seq_id[s]) {
                        ss << s%10;
                    } else {
                        ss << ".";
                    }
                }

                if (ubatch.token) {
                    LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
                            __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
                            ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
                } else {
                    LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
                            __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
                }
            }
            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
        }
    }
}

//
// interface implementation
//

struct llama_batch llama_batch_get_one(
        llama_token * tokens,
        int32_t n_tokens) {
    return {
        /*n_tokens =*/ n_tokens,
        /*tokens   =*/ tokens,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };
}

struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch batch = {
        /*n_tokens =*/ 0,
        /*tokens   =*/ nullptr,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };

    if (embd) {
        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
    } else {
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
    }

    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    for (int i = 0; i < n_tokens_alloc; ++i) {
        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch.seq_id[n_tokens_alloc] = nullptr;

    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

    return batch;
}
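
// usage sketch (hypothetical caller, not part of this file): a batch created with
// llama_batch_init() owns its buffers and must be released with llama_batch_free(),
// while llama_batch_get_one() only wraps caller-owned token memory:
//
//   llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ 512, /*embd =*/ 0, /*n_seq_max =*/ 1);
//   // ... fill batch.token, batch.pos, batch.n_seq_id, batch.seq_id, batch.logits ...
//   // ... submit the batch for decoding ...
//   llama_batch_free(batch);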

void llama_batch_free(struct llama_batch batch) {
    if (batch.token)    free(batch.token);
    if (batch.embd)     free(batch.embd);
    if (batch.pos)      free(batch.pos);
    if (batch.n_seq_id) free(batch.n_seq_id);
    if (batch.seq_id) {
        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
            free(batch.seq_id[i]);
        }
        free(batch.seq_id);
    }
    if (batch.logits)   free(batch.logits);
}
