#include "set-rows.cuh"
#include "cpy-utils.cuh"

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
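// Each thread quantizes one block of qk contiguous floats from src0 and writes the resulting
// quantized block into the destination row selected by src1.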
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
                                        const idx_t * __restrict__ src1,
                                        block_type * __restrict__ dst,
                                        const int64_t ne_total,
                                        const int64_t ne10,
                                        const int64_t ne11,
                                        const int64_t ne12,
                                        const int64_t ne13,
                                        const int64_t s01,
                                        const int64_t s02,
                                        const int64_t s03,
                                        const int64_t s10,
                                        const int64_t s11,
                                        const int64_t s12,
                                        const int64_t s1,
                                        const int64_t s2,
                                        const int64_t s3,
                                        const uint3 ne00,
                                        const uint3 ne01,
                                        const uint3 ne02,
                                        const uint3 ne11_fd,
                                        const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;

    if (i >= ne_total) {
        return;
    }

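    // map the flat index of this block's first float back to the source coordinates
    // (i00, i01, i02, i03) using the precomputed fastdiv values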
    const int64_t i_base = i * qk;
    uint32_t tmp = (uint32_t) i_base;
    uint2 div_mod;

    div_mod = fast_div_modulo(tmp, ne00);
    const int64_t i00 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne01);
    const int64_t i01 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne02);
    const int64_t i02 = div_mod.y;
    const int64_t i03 = div_mod.x;

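    // src1 is broadcast along dims 2 and 3 of src0; it stores, for every source row,
    // the index of the destination row to write to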
    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);

    const float * src_block = src0_row + i00;
    block_type * dst_block = dst_row_ptr + i00 / qk;

    quantize_func(src_block, dst_block);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}

// Template dispatch function for quantized set_rows
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
        const float * src0_d, const idx_t * src1_d, block_type * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    GGML_ASSERT(ne00 % qk == 0);
    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

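    // convert the byte strides of src0 and src1 to element strides; the destination strides stay
    // in bytes and are divided by sizeof(block_type) inside the kernel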
    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1 = nb1;
    const int64_t s2 = nb2;
    const int64_t s3 = nb3;

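    // skip the launch for empty tensors: there is nothing to write and the fastdiv
    // precomputation below expects non-zero divisors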
    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
    }
}

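// Generic set_rows kernel: each thread copies one element of src0 into the destination row
// selected by src1, converting it to dst_t on the fly.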
template <typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(const src_t * __restrict__ src0,
                                  const idx_t * __restrict__ src1,
                                  dst_t * __restrict__ dst,
                                  const int64_t ne_total,
                                  const int64_t ne10,
                                  const int64_t ne11,
                                  const int64_t ne12,
                                  const int64_t ne13,
                                  const int64_t s01,
                                  const int64_t s02,
                                  const int64_t s03,
                                  const int64_t s10,
                                  const int64_t s11,
                                  const int64_t s12,
                                  const int64_t s1,
                                  const int64_t s2,
                                  const int64_t s3,
                                  const uint3 ne00,
                                  const uint3 ne01,
                                  const uint3 ne02,
                                  const uint3 ne11_fd,
                                  const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;

    if (i >= ne_total) {
        return;
    }

    uint32_t tmp = (uint32_t) i;
    uint2 div_mod;

    div_mod = fast_div_modulo(tmp, ne00);
    const int64_t i00 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne01);
    const int64_t i01 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne02);
    const int64_t i02 = div_mod.y;
    const int64_t i03 = div_mod.x;

    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;

    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}

template<typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
        const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1 = nb1/sizeof(dst_t);
    const int64_t s2 = nb2/sizeof(dst_t);
    const int64_t s3 = nb3/sizeof(dst_t);

    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
                                                         ne11_fd, ne12_fd);
    }
}

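// Dispatch over the destination type: F32/F16/BF16 use the plain copy/convert kernel,
// quantized destinations go through the block-quantizing kernel.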
template<typename src_t, typename idx_t>
static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const src_t * src0_d = (const src_t *)src0->data;
    const idx_t * src1_d = (const idx_t *)src1->data;

    GGML_TENSOR_BINARY_OP_LOCALS

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda(
            src0_d, src1_d, (half*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_BF16) {
        set_rows_cuda(
            src0_d, src1_d, (nv_bfloat16*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_0) {
        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else {
        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}
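// backend entry point for GGML_OP_SET_ROWS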
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);

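    // the row indices in src1 may be stored as either 64-bit or 32-bit integers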
    if (src1->type == GGML_TYPE_I64) {
        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
    } else {
        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
    }
}