#pragma once

#include "llama.h"

#include "llama-cparams.h"

#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>

// keep this struct lightweight
struct llama_ubatch {
    bool equal_seqs() const {
        return b_equal_seqs != 0;
    }

    // typical for M-RoPE cases:
    //   0 - sequential position of the tokens/embeddings in the sequence
    //   1 - y position in the image
    //   2 - x position in the image
    //   3 - other
    bool is_pos_2d() const {
        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
        return n_pos >= 3;
    }
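
    // note: when n_pos > 1 (e.g. the M-RoPE case above), the pos buffer below carries
    //       n_pos position values for each token - see its [n_tokens*n_pos] size;
    //       the exact layout is produced by llama_batch_allocr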

    uint32_t b_equal_seqs; // note: this is a boolean, but we use a uint32_t for alignment
                           //       otherwise address sanitizer complains
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence set
    uint32_t n_seqs;       // sequence sets in the ubatch
    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
    uint32_t n_pos;        // number of position inputs for each token/embedding

    // seq_id_unq: unique sequence ids in the ubatch
    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
    //             used for extracting sequence pooled embeddings
    //                             size               | idx | val
    llama_token  *  token;      // [n_tokens]         | i   | id, token
    float        *  embd;       // [n_embd, n_tokens] | i   | embd
    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
    int8_t       *  output;     // [n_tokens]         | i   | -

    struct data_t {
        std::vector<llama_token>    token;
        std::vector<float>          embd;
        std::vector<llama_pos>      pos;
        std::vector<int32_t>        n_seq_id;
        std::vector<llama_seq_id *> seq_id;
        std::vector<llama_seq_id>   seq_id_unq;
        std::vector<int32_t>        seq_idx;
        std::vector<int8_t>         output;
    };

    // the llama_ubatch pointers above point to this data if set; otherwise they point to non-owning data
    std::shared_ptr<data_t> data;
};
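
// example (sketch): mapping a sequence id to its index among the unique sequences of a
// ubatch via seq_id_unq/seq_idx, as documented above - e.g. when reading per-sequence
// pooled embeddings; `ub` is an illustrative llama_ubatch, not part of this header:
//
//   for (uint32_t s = 0; s < ub.n_seqs_unq; ++s) {
//       const llama_seq_id seq_id = ub.seq_id_unq[s];
//       const int32_t      idx    = ub.seq_idx[seq_id]; // == s here, in [0, n_seqs_unq)
//       // use `idx` to locate the per-sequence data (e.g. a pooled embedding) for seq_id
//   }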

// a helper for sanitizing, fulfilling and splitting a batch
class llama_batch_allocr {
public:
    llama_batch_allocr(uint32_t n_pos_per_embd);

    // sanitize and auto-gen missing data in the input batch
    // memory is optional; if provided, it will be used to check for sequence continuity and to determine the positions
    bool init(
            const llama_batch & batch_inp,
            const llama_vocab & vocab,
            const llama_memory_i * memory,
            uint32_t n_embd,
            uint32_t n_seq_max,
            bool output_all);

    const llama_batch & get_batch() const;

    uint32_t get_n_tokens()  const;
    uint32_t get_n_outputs() const;
    uint32_t get_n_used()    const;

    // the array of output indices in the order they were encountered during the ubatch splitting
    std::vector<int32_t> & get_out_ids();

    // min/max positions of each sequence in the current ubatch
    llama_pos seq_pos_min(llama_seq_id seq_id) const;
    llama_pos seq_pos_max(llama_seq_id seq_id) const;

    // call once before splitting the batch to reset the internal state
    void split_reset();

    // simple split, unknown number of sequence sets of unequal lengths
    llama_ubatch split_simple(uint32_t n_ubatch);

    // make ubatches of equal-length sequence sets
    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

    // sequence-set-wise split - each ubatch contains a single sequence-set
    llama_ubatch split_seq(uint32_t n_ubatch);

    // a helper method for creating a well-defined ubatch of tokens
    // TODO: support embeddings if needed in the future
    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
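
    // example (sketch): a typical splitting loop, assuming the split_* methods return an
    // empty ubatch (n_tokens == 0) once the whole batch has been consumed (see ubatch_add
    // below); `balloc`, `n_ubatch` and the other arguments are illustrative:
    //
    //   if (!balloc.init(batch, vocab, memory, n_embd, n_seq_max, /*output_all=*/false)) {
    //       // invalid batch
    //   }
    //
    //   balloc.split_reset();
    //   while (true) {
    //       llama_ubatch ubatch = balloc.split_simple(n_ubatch);
    //       if (ubatch.n_tokens == 0) {
    //           break;
    //       }
    //       // process the ubatch
    //   }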

private:
    void clear();

    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);

    // for debugging, start with LLAMA_BATCH_DEBUG=2
    void ubatch_print(const llama_ubatch & ubatch, int debug);

    llama_batch batch;

    // only for debugging purposes
    const llama_vocab * vocab;

    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
    const uint32_t n_pos_per_embd;

    uint32_t n_embd;
    uint32_t n_seq_max;
    uint32_t n_outputs;

    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<llama_seq_id>   seq_id_unq;
    std::vector<int32_t>        seq_idx;
    std::vector<int8_t>         output;

    using pos_set_t = std::set<llama_pos>;
    using seq_cpl_t = std::vector<bool>;

    // helper flag to quickly determine if there are any coupled sequences in the batch
    bool has_cpl = false;

    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

    using idx_vec_t = std::vector<int32_t>;
    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i

    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears

    // batch indices of the output
    std::vector<int32_t> out_ids;

    uint32_t n_used;

    // used[i] indicates if token i has already been used in a previous ubatch
    std::vector<bool> used;

    int debug;
};