1#include "add-id.cuh"
2
// Computes dst = src0 + src1[id], one row at a time, where id is an int32 read from src2.
//
// Launch layout: one block per output row. blockIdx.x is the row index along dim 1
// and blockIdx.y is the index along dim 2. The threads of a block walk the ne0
// columns of their row in steps of blockDim.x.
//
// Addressing:
//   src0 - rows located with the byte strides nb01 (dim 1) and nb02 (dim 2)
//   src1 - row located with the byte stride nb11, selected by the id read from src2
//   src2 - int32 ids, packed along dim 0, byte stride nb21 along dim 1
//   dst  - addressed as densely packed floats: ne0 per row, ne1 rows per plane
static __global__ void add_id_kernel(
        const float * src0, const float * src1, const int32_t * src2, float * dst,
        int64_t ne0, int64_t ne1,
        size_t nb01, size_t nb02,
        size_t nb11,
        size_t nb21
    ) {
    const int64_t row   = blockIdx.x;
    const int64_t plane = blockIdx.y;

    // look up which src1 row has to be added to this output row
    const char * id_ptr = reinterpret_cast<const char *>(src2) + row*sizeof(int32_t) + plane*nb21;
    const int    id     = *reinterpret_cast<const int32_t *>(id_ptr);

    // dst strides are derived from its extents, i.e. no padding between rows or planes
    const size_t dst_row_bytes   = ne0 * sizeof(float);
    const size_t dst_plane_bytes = ne1 * dst_row_bytes;

    const char * lhs_bytes = reinterpret_cast<const char *>(src0) + row*nb01 + plane*nb02;
    const char * rhs_bytes = reinterpret_cast<const char *>(src1) + id*nb11;
    char       * out_bytes = reinterpret_cast<char *>(dst) + row*dst_row_bytes + plane*dst_plane_bytes;

    const float * lhs = reinterpret_cast<const float *>(lhs_bytes);
    const float * rhs = reinterpret_cast<const float *>(rhs_bytes);
    float       * out = reinterpret_cast<float *>(out_bytes);

    // block-stride loop so that rows wider than the block are fully covered
    int64_t col = threadIdx.x;
    while (col < ne0) {
        out[col] = lhs[col] + rhs[col];
        col += blockDim.x;
    }
}
27
// Launches add_id_kernel for `dst`: every row of dst is the matching row of src0
// plus the row of src1 selected by the int32 id stored in src2.
//
//   dst->src[0] (src0): F32, left-hand operand; dst must have the same shape
//   dst->src[1] (src1): F32, rows that are picked by id
//   dst->src[2] (src2): I32, one id per output row
//
// All tensors must be contiguous along dim 0 and dst must additionally be densely
// packed, because the kernel derives the dst strides from ne0/ne1 instead of
// receiving them. The work is enqueued on ctx.stream(); nothing is synchronized here.
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));
    GGML_ASSERT(nb20 == sizeof(int32_t));

    // the grid is sized from src0 while the kernel indexes dst with ne0/ne1,
    // so both tensors have to agree on their extents
    GGML_ASSERT(ne0 == ne00);
    GGML_ASSERT(ne1 == ne01);
    GGML_ASSERT(ne2 == ne02);

    // the kernel recomputes the dst strides as ne0*sizeof(float) and ne1*ne0*sizeof(float);
    // fail loudly instead of writing to the wrong addresses if dst is laid out differently
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb1 == ne0*sizeof(float));
    GGML_ASSERT(nb2 == ne1*nb1);

    // a launch with zero blocks or zero threads is an invalid configuration
    if (ne00 == 0 || ne01 == 0 || ne02 == 0) {
        return;
    }

    const float   * src0_d = (const float   *) src0->data;
    const float   * src1_d = (const float   *) src1->data;
    const int32_t * src2_d = (const int32_t *) src2->data;
    float         * dst_d  = (float         *) dst->data;

    // clamp in 64 bit first so that a very wide row cannot be truncated by the int cast
    const int threads = (int) std::min<int64_t>(ne00, 768); // cols

    // the y dimension of a grid is limited to 65535 blocks, so larger token counts
    // are processed in several launches with the base pointers advanced accordingly
    const int64_t max_grid_y = 65535;

    for (int64_t i2_base = 0; i2_base < ne02; i2_base += max_grid_y) {
        const int64_t n_i2 = std::min<int64_t>(ne02 - i2_base, max_grid_y);

        const dim3 blocks(ne01, n_i2); // n_experts_used, n_tokens (of this chunk)

        // src1 is indexed only through the ids, so it is not offset
        const float   * src0_c = (const float   *) ((const char *) src0_d + i2_base*nb02);
        const int32_t * src2_c = (const int32_t *) ((const char *) src2_d + i2_base*nb21);
        float         * dst_c  = (float         *) ((char       *) dst_d  + i2_base*nb2);

        add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
            src0_c, src1_d, src2_c, dst_c,
            ne0, ne1,
            nb01, nb02,
            nb11,
            nb21
        );
    }
}
59