#include "set-rows.cuh"
#include "cpy-utils.cuh"

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
                                        const idx_t * __restrict__ src1,
                                        block_type * __restrict__ dst,
                                        const int64_t ne_total,
                                        const int64_t ne10,
                                        const int64_t ne11,
                                        const int64_t ne12,
                                        const int64_t ne13,
                                        const int64_t s01,
                                        const int64_t s02,
                                        const int64_t s03,
                                        const int64_t s10,
                                        const int64_t s11,
                                        const int64_t s12,
                                        const int64_t s1,
                                        const int64_t s2,
                                        const int64_t s3,
                                        const uint3 ne00,
                                        const uint3 ne01,
                                        const uint3 ne02,
                                        const uint3 ne11_fd,
                                        const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;

    if (i >= ne_total) {
        return;
    }

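    // Each thread quantizes one qk-element block of src0. i_base is the flattened
    // element index of the block's first float; fastdiv decomposes it into the
    // tensor coordinates (i00, i01, i02, i03).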
    const int64_t i_base = i * qk;
    uint32_t tmp = (uint32_t) i_base;
    uint2 div_mod;

    div_mod = fast_div_modulo(tmp, ne00);
    const int64_t i00 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne01);
    const int64_t i01 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne02);
    const int64_t i02 = div_mod.y;
    const int64_t i03 = div_mod.x;

    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

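    // The destination row comes from src1, which is broadcast along dims 1 and 2.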
    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);

    const float * src_block = src0_row + i00;
    block_type * dst_block = dst_row_ptr + i00 / qk;

    quantize_func(src_block, dst_block);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}

// Template dispatch function for quantized set_rows
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
        const float * src0_d, const idx_t * src1_d, block_type * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    GGML_ASSERT(ne00 % qk == 0);
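    // Launch one thread per qk-element block of src0.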
    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
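    // dst strides are kept in bytes; the kernel divides the byte offset by sizeof(block_type).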
    const int64_t s1 = nb1;
    const int64_t s2 = nb2;
    const int64_t s3 = nb3;

    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
    }
}

template <typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(const src_t * __restrict__ src0,
                                  const idx_t * __restrict__ src1,
                                  dst_t * __restrict__ dst,
                                  const int64_t ne_total,
                                  const int64_t ne10,
                                  const int64_t ne11,
                                  const int64_t ne12,
                                  const int64_t ne13,
                                  const int64_t s01,
                                  const int64_t s02,
                                  const int64_t s03,
                                  const int64_t s10,
                                  const int64_t s11,
                                  const int64_t s12,
                                  const int64_t s1,
                                  const int64_t s2,
                                  const int64_t s3,
                                  const uint3 ne00,
                                  const uint3 ne01,
                                  const uint3 ne02,
                                  const uint3 ne11_fd,
                                  const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;

    if (i >= ne_total) {
        return;
    }

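    // One thread per element: decompose the flattened element index into
    // (i00, i01, i02, i03) with fastdiv, mirroring the quantized kernel above.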
    uint32_t tmp = (uint32_t) i;
    uint2 div_mod;

    div_mod = fast_div_modulo(tmp, ne00);
    const int64_t i00 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne01);
    const int64_t i01 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne02);
    const int64_t i02 = div_mod.y;
    const int64_t i03 = div_mod.x;

    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

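    // Look up the destination row in src1 (broadcast along dims 1 and 2).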
    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;

    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}

template<typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
        const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1 = nb1/sizeof(dst_t);
    const int64_t s2 = nb2/sizeof(dst_t);
    const int64_t s3 = nb3/sizeof(dst_t);

    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
                                                         ne11_fd, ne12_fd);
    }
}

template<typename src_t, typename idx_t>
static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const src_t * src0_d = (const src_t *)src0->data;
    const idx_t * src1_d = (const idx_t *)src1->data;

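    // GGML_TENSOR_BINARY_OP_LOCALS declares the ne0*/nb0* (src0), ne1*/nb1* (src1)
    // and ne*/nb* (dst) dimension and stride locals used below.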
    GGML_TENSOR_BINARY_OP_LOCALS

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda(
            src0_d, src1_d, (half*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_BF16) {
        set_rows_cuda(
            src0_d, src1_d, (nv_bfloat16*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_0) {
        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else {
        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);

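    // src1 may hold the row indices as either 64-bit or 32-bit integers.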
    if (src1->type == GGML_TYPE_I64) {
        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
    } else {
        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
    }
}