#pragma once

#include "llama.h"

#include "llama-cparams.h"

#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>

// keep this struct lightweight
struct llama_ubatch {
    bool equal_seqs() const {
        return b_equal_seqs != 0;
    }

    // typical for M-RoPE cases:
    // 0 - sequential position of the tokens/embeddings in the sequence
    // 1 - y position in the image
    // 2 - x position in the image
    // 3 - other
    bool is_pos_2d() const {
        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
        return n_pos >= 3;
    }
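
    // illustrative sketch (assuming the dimension-major layout used for multiple positions,
    // i.e. all values of one position dimension stored contiguously in pos):
    //
    //   // position of token i in dimension d, with d in [0, n_pos):
    //   llama_pos p = pos[d*n_tokens + i];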

    uint32_t b_equal_seqs; // note: this is a boolean, but we use a uint32_t for alignment
                           // otherwise address sanitizer complains
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence set
    uint32_t n_seqs;       // sequence sets in the ubatch
    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
    uint32_t n_pos;        // number of position inputs for each token/embedding

    // seq_id_unq: unique sequence ids in the ubatch
    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
    //             used for extracting sequence pooled embeddings

    //                          // size               | idx | val
    llama_token  *  token;      // [n_tokens]         | i   | id, token
    float        *  embd;       // [n_embd, n_tokens] | i   | embd
    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
    int8_t       *  output;     // [n_tokens]         | i   | -
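
    // illustrative sketch (hypothetical consumer code): looking up the row of the pooled
    // embedding for sequence id s (assumes seq_idx entries are -1 for absent sequences):
    //
    //   const int32_t row = ub.seq_idx[s];
    //   if (row >= 0) {
    //       // sequence s is present; its pooled embedding is at row `row` in [0, n_seqs_unq)
    //   }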

    struct data_t {
        std::vector<llama_token>    token;
        std::vector<float>          embd;
        std::vector<llama_pos>      pos;
        std::vector<int32_t>        n_seq_id;
        std::vector<llama_seq_id *> seq_id;
        std::vector<llama_seq_id>   seq_id_unq;
        std::vector<int32_t>        seq_idx;
        std::vector<int8_t>         output;
    };

    // if set, the llama_ubatch pointers above point into this data; otherwise they point to non-owning data
    std::shared_ptr<data_t> data;
};
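
// illustrative sketch (hypothetical consumer code): iterating the tokens of a ubatch
// together with their sequence ids:
//
//   for (uint32_t i = 0; i < ub.n_tokens; ++i) {
//       const llama_token id = ub.token[i];
//       for (int32_t s = 0; s < ub.n_seq_id[i]; ++s) {
//           const llama_seq_id seq = ub.seq_id[i][s];
//           // token i with id `id` belongs to sequence `seq` at position ub.pos[i]
//       }
//   }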

// a helper for sanitizing a batch, filling in missing data, and splitting it into ubatches
class llama_batch_allocr {
public:
    llama_batch_allocr(uint32_t n_pos_per_embd);

    // sanitize and auto-generate missing data in the input batch
    // memory is optional - if provided, it is used to check for sequence continuity and to determine the positions
    bool init(
            const llama_batch & batch_inp,
            const llama_vocab & vocab,
            const llama_memory_i * memory,
            uint32_t n_embd,
            uint32_t n_seq_max,
            bool output_all);
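
    // illustrative sketch (hypothetical caller; variable names assumed):
    //
    //   llama_batch_allocr balloc(n_pos_per_embd);
    //   if (!balloc.init(batch, vocab, memory, n_embd, n_seq_max, /*output_all=*/false)) {
    //       // the input batch failed validation
    //   }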

    const llama_batch & get_batch() const;

    uint32_t get_n_tokens()  const;
    uint32_t get_n_outputs() const;
    uint32_t get_n_used()    const;

    // the array of output indices in the order they were encountered during the ubatch splitting
    std::vector<int32_t> & get_out_ids();

    // min/max positions of each sequence in the current ubatch
    llama_pos seq_pos_min(llama_seq_id seq_id) const;
    llama_pos seq_pos_max(llama_seq_id seq_id) const;

    // call once before splitting the batch to reset the internal state
    void split_reset();

    // simple split, unknown number of sequence sets of unequal lengths
    llama_ubatch split_simple(uint32_t n_ubatch);

    // make ubatches of equal-length sequence sets
    // if sequential == true, the tokens in the ubatch will have consecutive, increasing sequence ids
    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

    // sequence-set-wise split - each ubatch contains a single sequence set
    llama_ubatch split_seq(uint32_t n_ubatch);
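
    // illustrative sketch (hypothetical caller): consuming a batch as a stream of ubatches -
    // per the ubatch_add contract below, a ubatch with n_tokens == 0 signals that the batch
    // has been fully consumed:
    //
    //   balloc.split_reset();
    //   while (true) {
    //       llama_ubatch ubatch = balloc.split_simple(n_ubatch);
    //       if (ubatch.n_tokens == 0) {
    //           break;
    //       }
    //       // ... process ubatch ...
    //   }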

    // a helper method for creating a well-defined ubatch of tokens
    // TODO: support embeddings if needed in the future
    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);

private:
    void clear();

    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);

    // for debugging, run with the environment variable LLAMA_BATCH_DEBUG=2
    void ubatch_print(const llama_ubatch & ubatch, int debug);

    llama_batch batch;

    // only for debugging purposes
    const llama_vocab * vocab;

    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
    const uint32_t n_pos_per_embd;

    uint32_t n_embd;
    uint32_t n_seq_max;
    uint32_t n_outputs;

    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<llama_seq_id>   seq_id_unq;
    std::vector<int32_t>        seq_idx;
    std::vector<int8_t>         output;

    using pos_set_t = std::set<llama_pos>;
    using seq_cpl_t = std::vector<bool>;

    // helper flag to quickly determine if there are any coupled sequences in the batch
    bool has_cpl = false;

    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: whether sequence s0 is coupled to sequence s1

    using idx_vec_t = std::vector<int32_t>;
    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i

    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
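
    // illustrative sketch (assumed example): for a batch where
    //   token 0 -> seqs {0}, token 1 -> seqs {0}, token 2 -> seqs {0, 1}
    // seq_set_map would contain:
    //   {0}    -> [0, 1]
    //   {0, 1} -> [2]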

    // batch indices of the output
    std::vector<int32_t> out_ids;

    uint32_t n_used;

    // used[i] indicates if token i has already been used in a previous ubatch
    std::vector<bool> used;

    int debug;
};