#include "binbcast.cuh"
#include <cstdint>
#include <utility>

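// Scalar device functions for the supported binary broadcast ops.
// op_repeat simply returns the second operand, so GGML_OP_REPEAT can reuse
// the generic broadcast kernel with src0 unused.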
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
    return b;
    GGML_UNUSED(a);
}

static __device__ __forceinline__ float op_add(const float a, const float b) {
    return a + b;
}

static __device__ __forceinline__ float op_sub(const float a, const float b) {
    return a - b;
}

static __device__ __forceinline__ float op_mul(const float a, const float b) {
    return a * b;
}

static __device__ __forceinline__ float op_div(const float a, const float b) {
    return a / b;
}

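// Generic broadcast kernel: each thread owns one (i1, i2, i3) row of dst and
// strides over the innermost dimension i0. src1 indices are wrapped with
// fastmodulo() to implement broadcasting. When the variadic src1s pack is
// non-empty (fused ops), the extra operands are folded into the result in
// order and the plain src1 pointer is ignored.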
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0,
                                   const src1_t * src1,
                                   dst_t * dst,
                                   const int ne0,
                                   const int ne1,
                                   const int ne2,
                                   const uint3 ne3,
                                   const uint3 ne10,
                                   const uint3 ne11,
                                   const uint3 ne12,
                                   const uint3 ne13,
                                   /*int s0, */ const int s1,
                                   const int s2,
                                   const int s3,
                                   /*int s00,*/ const int s01,
                                   const int s02,
                                   const int s03,
                                   /*int s10,*/ const int s11,
                                   const int s12,
                                   const int s13,
                                   src1_ptrs... src1s) {
    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);

    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
        return;
    }

    const uint32_t i11 = fastmodulo(i1, ne11);
    const uint32_t i12 = fastmodulo(i2, ne12);
    const uint32_t i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
        const uint32_t i10 = fastmodulo(i0, ne10);

        float result = src0_row ? (float) src0_row[i0] : 0.0f;
        if constexpr (sizeof...(src1_ptrs) > 0) {
            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
        } else {
            result = bin_op(result, (float)src1[i_src1 + i10]);
        }

        dst_row[i0] = (dst_t) result;
    }
}

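// Variant of k_bin_bcast that linearizes all four dst dimensions into a single
// 1D thread index and unravels it with fastdiv(). Used when the 3D launch grid
// of k_bin_bcast would exceed CUDA's 65535 limit on the y/z grid dimensions.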
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
                                           const src1_t * src1,
                                           dst_t * dst,
                                           const uint3 ne0,
                                           const uint3 ne1,
                                           const uint3 ne2,
                                           const uint32_t ne3,
                                           const uint3 prod_012,
                                           const uint3 prod_01,
                                           const uint3 ne10,
                                           const uint3 ne11,
                                           const uint3 ne12,
                                           const uint3 ne13,
                                           /*int s0, */ const int s1,
                                           const int s2,
                                           const int s3,
                                           /*int s00,*/ const int s01,
                                           const int s02,
                                           const int s03,
                                           /*int s10,*/ const int s11,
                                           const int s12,
                                           const int s13,
                                           src1_ptrs... src1s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    const uint32_t i3 = fastdiv(i, prod_012);
    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;

    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
        return;
    }

    const int i11 = fastmodulo(i1, ne11);
    const int i12 = fastmodulo(i2, ne12);
    const int i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    const int i10 = fastmodulo(i0, ne10);

    float result = src0_row ? (float) src0_row[i0] : 0.0f;
    if constexpr (sizeof...(src1_ptrs) > 0) {
        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
    } else {
        result = bin_op(result, (float)src1[i_src1 + i10]);
    }

    dst_row[i0] = (dst_t) result;
}

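// Host-side launcher shared by the plain and fused paths. It collapses leading
// dimensions that are not broadcast, converts byte strides (nb*) into element
// strides (s*), and then picks either the 3D-grid kernel or the 1D "unravel"
// kernel depending on the resulting grid size. The index_sequence I... selects
// which of dst->src[1..N] are passed to the kernel as fused src1 operands.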
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                  const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                                  cudaStream_t stream, std::index_sequence<I...>) {
    GGML_TENSOR_BINARY_OP_LOCALS

    int nr0 = ne10 / ne0;
    int nr1 = ne11 / ne1;
    int nr2 = ne12 / ne2;
    int nr3 = ne13 / ne3;

    int nr[4] = { nr0, nr1, nr2, nr3 };

    int64_t cne[] = { ne0, ne1, ne2, ne3 };
    int64_t cne0[] = { ne00, ne01, ne02, ne03 };
    int64_t cne1[] = { ne10, ne11, ne12, ne13 };

    size_t cnb[] = { nb0, nb1, nb2, nb3 };
    size_t cnb0[] = { nb00, nb01, nb02, nb03 };
    size_t cnb1[] = { nb10, nb11, nb12, nb13 };

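    // Helpers used below: collapse() folds dimension 1 into dimension 0 of a
    // size array, and collapse_nb() updates the matching stride array so the
    // folded layout still addresses the same memory.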
    auto collapse = [](int64_t cne[]) {
        cne[0] *= cne[1];
        cne[1] = cne[2];
        cne[2] = cne[3];
        cne[3] = 1;
    };

    auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
        cnb[1] *= cne[1];
        cnb[2] *= cne[2];
        cnb[3] *= cne[3];
    };

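    // For fully contiguous tensors, fold leading dimensions that are not
    // broadcast (nr[i] == 1) into dimension 0 to reduce the effective rank.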
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
        for (int i = 0; i < 4; i++) {
            if (nr[i] != 1) {
                break;
            }
            if (i > 0) {
                collapse_nb(cnb, cne);
                collapse_nb(cnb0, cne0);
                collapse_nb(cnb1, cne1);
                collapse(cne);
                collapse(cne0);
                collapse(cne1);
            }
        }
    }

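    // Re-derive the (possibly collapsed) sizes and byte strides, then convert
    // the strides to element counts for the kernels.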
    {
        int64_t ne0 = cne[0];
        int64_t ne1 = cne[1];
        int64_t ne2 = cne[2];
        int64_t ne3 = cne[3];

        //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
        //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
        //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
        //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

        size_t nb0 = cnb[0];
        size_t nb1 = cnb[1];
        size_t nb2 = cnb[2];
        size_t nb3 = cnb[3];

        size_t nb00 = cnb0[0];
        size_t nb01 = cnb0[1];
        size_t nb02 = cnb0[2];
        size_t nb03 = cnb0[3];

        size_t nb10 = cnb1[0];
        size_t nb11 = cnb1[1];
        size_t nb12 = cnb1[2];
        size_t nb13 = cnb1[3];

        size_t s0 = nb0 / sizeof(dst_t);
        size_t s1 = nb1 / sizeof(dst_t);
        size_t s2 = nb2 / sizeof(dst_t);
        size_t s3 = nb3 / sizeof(dst_t);

        size_t s10 = nb10 / sizeof(src1_t);
        size_t s11 = nb11 / sizeof(src1_t);
        size_t s12 = nb12 / sizeof(src1_t);
        size_t s13 = nb13 / sizeof(src1_t);

        size_t s00 = nb00 / sizeof(src0_t);
        size_t s01 = nb01 / sizeof(src0_t);
        size_t s02 = nb02 / sizeof(src0_t);
        size_t s03 = nb03 / sizeof(src0_t);

        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);

        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);

        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);

        GGML_ASSERT(s0 == 1);
        GGML_ASSERT(s00 == 1);
        GGML_ASSERT(s10 == 1);

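        // Launch configuration: the grid only covers ne0/2 columns in x because
        // the kernel's grid-stride loop over i0 lets each thread process more
        // than one element of the innermost dimension.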
        const int block_size = 128;

        int64_t hne0 = std::max(ne0 / 2LL, 1LL);

        dim3 block_dims;
        block_dims.x = std::min<unsigned int>(hne0, block_size);
        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);

        const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
        const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
        const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
        const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

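        // CUDA grids are limited to 65535 blocks in the y and z dimensions, so
        // fall back to the 1D "unravel" kernel when the 3D grid would overflow.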
        if (block_nums.z > 65535 || block_nums.y > 65535) {
            int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
            const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
            const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
            const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
            const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
            const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                    ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                           ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
                                                           /* s0, */ s1, s2, s3,
                                                           /* s00,*/ s01, s02, s03,
                                                           /* s10,*/ s11, s12, s13);
            }
        } else {
            const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13);
            }
        }
    }
}

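// Kernel for GGML_OP_REPEAT_BACK: each thread computes one element of dst by
// summing every src element that was produced from it by the forward repeat,
// i.e. it reduces over the broadcast copies.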
template <typename T>
static __global__ void k_repeat_back(
    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {

    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
    const int64_t tid2  = tid23 % ne2;
    const int64_t tid3  = tid23 / ne2;

    if (tid0 >= ne0) {
        return;
    }

    T sum = 0;
    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
                }
            }
        }
    }
    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}

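// Functor wrapping launch_bin_bcast_pack for a fixed scalar op and fuse count,
// so it can be passed as the template argument of ggml_cuda_op_bin_bcast.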
template <float (*bin_op)(const float, const float), int n_fuse = 1>
struct bin_bcast_cuda {
    template<typename src0_t, typename src1_t, typename dst_t>
    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
                    const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                    cudaStream_t stream) {
        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
    }
};

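// Launch helper for k_repeat_back: one thread per dst element, with dim 0 split
// into WARP_SIZE-wide blocks and dims 2/3 flattened into the z grid dimension.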
template <typename T>
static void repeat_back_cuda(
    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}

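// Dispatches a binary broadcast op to the correct template instantiation based
// on the src0/src1/dst tensor types (currently F32 and F16 combinations).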
template<class op>
static void ggml_cuda_op_bin_bcast(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}

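// Public entry points for the unfused ops. GGML_OP_REPEAT is routed through the
// same broadcast machinery with dst standing in for src0 (src0_dd == nullptr)
// and n_fuse == 0, so op_repeat simply copies the broadcast source into dst.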
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

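// Fused variant: applies the same binary op to dst->src[1..n_fuse] in a single
// kernel launch, with the operand pointers passed through the variadic pack.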
template <float (*op)(const float, const float), int n_fuse>
static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    cudaStream_t stream = ctx.stream();

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else {
        fprintf(stderr,
            "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
            __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}

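// Maps the runtime fuse count (2..8 tensors added together) onto the
// compile-time n_fuse template parameter.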
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);

    switch (n_fuse) {
        case 2:
            ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
            break;
        case 3:
            ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
            break;
        case 4:
            ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
            break;
        case 5:
            ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
            break;
        case 6:
            ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
            break;
        case 7:
            ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
            break;
        case 8:
            ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false && "Unsupported n_fuse value");
    }
}

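// GGML_OP_REPEAT_BACK: reduces a repeated tensor back to the shape of dst.
// Only F32 is currently supported, and ne2*ne3 is limited to 1 << 15 so it
// fits in the z grid dimension of repeat_back_cuda's launch.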
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_can_repeat(dst, src0));

    cudaStream_t stream = ctx.stream();

    GGML_TENSOR_UNARY_OP_LOCALS;

    GGML_ASSERT(ne2*ne3 <= (1 << 15));

    const size_t ts = ggml_type_size(src0->type);
    const size_t s00 = nb00 / ts;
    const size_t s01 = nb01 / ts;
    const size_t s02 = nb02 / ts;
    const size_t s03 = nb03 / ts;

    switch (dst->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            float * dst_d = (float *) dst->data;
            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
        } break;
        default: {
            GGML_ASSERT(false);
        } break;
    }
}