#include "cpy.cuh"
#include "dequantize.cuh"
#include "cpy-utils.cuh"
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
const int CUDA_CPY_BLOCK_NM    = 8;  // block size of 3rd dimension if available
const int CUDA_CPY_BLOCK_ROWS  = 8;  // block dimension for marching through rows

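// Generic strided element-wise copy: one thread per element. The flattened index i
// is decomposed into (i03, i02, i01, i00) for the source and (i13, i12, i11, i10)
// for the destination, then combined with the per-dimension byte strides.
// Illustrative example (not from the original source): with ne00 = 4, ne01 = 3,
// ne02 = 2 and i = 17, the decomposition gives i03 = 0, i02 = 1, i01 = 1, i00 = 1.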
template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
        const int nb12, const int nb13) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
    }

    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
    // then combine those indices with the corresponding byte offsets to get the total offsets
    const int64_t i03 = i/(ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    const int64_t i13 = i/(ne10 * ne11 * ne12);
    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_1(cx + x_offset, cdst + dst_offset);
}

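// Tiled transpose of the two innermost dimensions. Each block handles a
// CUDA_CPY_TILE_DIM_2D x CUDA_CPY_TILE_DIM_2D tile; the shared-memory tile is
// padded by one float per row to avoid shared-memory bank conflicts. Threads
// step through the tile CUDA_CPY_BLOCK_ROWS rows at a time, and blockIdx.z
// batches up to CUDA_CPY_BLOCK_NM matrices per block. The indexing assumes the
// elements of each ne00 x ne01 matrix are packed contiguously as type T.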
template <typename T>
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
        const int nb12, const int nb13) {

    const T * src = reinterpret_cast<const T *>(cx);
    T       * dst = reinterpret_cast<T *>(cdst);

    const int64_t nmat = ne / (ne00 * ne01);
    const int64_t n    = ne00 * ne01;

    const int x  = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
    const int y  = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;

    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];

#pragma unroll
    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {

        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
        if (imat >= nmat) {
            break;
        }

#pragma unroll
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (x < ne01 && y + j < ne00) {
                const int row = threadIdx.y + j;
                const int col = threadIdx.x * sizeof(float)/sizeof(T);
                T * tile2 = reinterpret_cast<T *>(tile[row]);
                tile2[col] = src[imat*n + (y+j)*ne01 + x];
            }
        }

        __syncthreads();

#pragma unroll
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (ty + j < ne01 && tx < ne00) {
                const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
                const T * tile2 = reinterpret_cast<const T *>(tile[threadIdx.x]);
                dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
            }
        }

        // make sure all warps have finished reading the tile before the next
        // matrix in the batch overwrites it
        __syncthreads();
    }
}

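// Dequantize one q8_0 block into QK8_0 consecutive floats, two values per step.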
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

#pragma unroll
    for (int j = 0; j < QK8_0; j += 2) {
        float2 dq;
        dequantize_q8_0(cxi, 0, j, dq);
        *(cdstf + j)     = dq.x;
        *(cdstf + j + 1) = dq.y;
    }
}

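// Generic block dequantization: each call to dequant yields a float pair whose
// two values land qk/2 positions apart, matching the interleaved layout of the
// q4_0/q4_1/q5_0/q5_1 blocks.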
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

#pragma unroll
    for (int j = 0; j < qk/2; j++) {
        float2 dq;
        dequant(cxi, 0, j, dq);
        *(cdstf + j)        = dq.x;
        *(cdstf + j + qk/2) = dq.y;
    }
}

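// Copy a float source into a quantized destination: each thread handles one block
// of qk values, so the launch uses ne/qk threads in total. The destination offset
// uses i10/qk because nb10 steps over whole quantized blocks.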
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
        const int nb12, const int nb13) {
    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

    const int i03 = i/(ne00 * ne01 * ne02);
    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    const int i13 = i/(ne10 * ne11 * ne12);
    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}

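// Inverse direction: copy a quantized source into a float destination. Here the
// source offset uses i00/qk because nb00 steps over whole quantized blocks.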
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
        const int nb12, const int nb13) {
    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

    const int i03 = i/(ne00 * ne01 * ne02);
    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    const int i13 = i/(ne10 * ne11 * ne12);
    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}

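// Fast path for contiguous tensors: a flat index-for-index cast with no stride
// arithmetic.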
template<typename src_t, typename dst_t>
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
    }

    const src_t * x = (const src_t *) cx;
    dst_t * dst = (dst_t *) cdst;

    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}

template<typename src_t, typename dst_t>
static void ggml_cpy_flt_contiguous_cuda(
    const char * cx, char * cdst, const int64_t ne,
    cudaStream_t stream) {

    const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne);
}

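// Host launcher for float-type copies. With transposed == true the two innermost
// dimensions are swapped by cpy_flt_transpose; when nb00 > nb02 the ne01 and ne02
// dimensions are collapsed into one so that a single 2D transpose covers the
// whole batch. Otherwise the generic strided cpy_flt kernel is launched.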
template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    if (transposed) {
        GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
        int ne00n, ne01n, ne02n;
        if (nb00 <= nb02) { // most likely safe to handle the nb00 == nb02 case here
            ne00n = ne00;
            ne01n = ne01;
            ne02n = ne02;
        } else { // nb00 > nb02: collapse ne01 and ne02 into a single dimension
            ne00n = ne00;
            ne01n = ne01*ne02;
            ne02n = 1;
        }

        dim3 dimGrid((ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
                     (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
                     (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
    } else {
        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
    }
}

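// Host launchers for copies between f32 and the quantized formats. The
// f32 -> quant launchers use ne/qk single-thread blocks, one per quantized
// block; the quant -> f32 launchers use ne single-thread blocks, of which only
// the first ne/qk pass the bounds check in cpy_q_f32 and do any work.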
static void ggml_cpy_f32_q8_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK8_0 == 0);
    const int num_blocks = ne / QK8_0;
    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q8_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q4_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK4_0 == 0);
    const int num_blocks = ne / QK4_0;
    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q4_1_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK4_1 == 0);
    const int num_blocks = ne / QK4_1;
    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_1_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q5_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK5_0 == 0);
    const int num_blocks = ne / QK5_0;
    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q5_1_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK5_1 == 0);
    const int num_blocks = ne / QK5_1;
    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_1_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_iq4_nl_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK4_NL == 0);
    const int num_blocks = ne / QK4_NL;
    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

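// Top-level dispatch for GGML_OP_CPY: picks a path based on the (src0, src1) type
// pair. Same-type contiguous tensors use a plain cudaMemcpyAsync (or the muDNN
// copy on MUSA builds). Same-type F32/F16/BF16 copies where src0 has nb01 equal
// to the element size (a transposed view) and ne[3] == 1 go through the
// shared-memory transpose kernel. Everything else uses the strided or contiguous
// cast kernels, or the (de)quantization copy kernels.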
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));

    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];

    //GGML_ASSERT(src0->ne[3] == 1);

    const int64_t nb00 = src0->nb[0];
    const int64_t nb01 = src0->nb[1];
    const int64_t nb02 = src0->nb[2];
    const int64_t nb03 = src0->nb[3];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne12 = src1->ne[2];

    //GGML_ASSERT(src1->ne[3] == 1);

    const int64_t nb10 = src1->nb[0];
    const int64_t nb11 = src1->nb[1];
    const int64_t nb12 = src1->nb[2];
    const int64_t nb13 = src1->nb[3];

    cudaStream_t main_stream = ctx.stream();

    char * src0_ddc = (char *) src0->data;
    char * src1_ddc = (char *) src1->data;

    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
    const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;

    if (src0->type == src1->type && contiguous_srcs) {
        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
        } else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
        {
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        if (can_be_transposed) {
            ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        } else {
            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        if (can_be_transposed) {
            ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        } else {
            ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
        if (can_be_transposed) {
            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        } else {
            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));
    }
}

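// GGML_OP_DUP is implemented as a plain copy of src0 into the destination tensor.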
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    ggml_cuda_cpy(ctx, src0, dst);
}