| 1 | #include "add-id.cuh" |
| 2 | |
// add_id_kernel: dst[i0, i1, i2] = src0[i0, i1, i2] + src1[i0, id], where id is the int32
// stored in src2 at position (i1, i2).
//
// Launch layout:
//   - blockIdx.x selects the row i1 and blockIdx.y selects the plane i2.
//   - The threads of one block walk the ne0 columns of that row in steps of blockDim.x, so any
//     blockDim.x >= 1 covers the whole row.
//   - No shared memory is used.
//
// Addressing:
//   - src0 rows/planes are located with the byte strides nb01/nb02.
//   - src1 rows are located with nb11.
//   - The id for (i1, i2) is the i1-th int32 of src2 row i2 (row stride nb21).
//   - Elements inside a row are assumed packed floats.
//   - dst is addressed as a packed float tensor: its row/plane strides are derived here from
//     ne0/ne1 rather than passed in, so the caller must guarantee dst is contiguous.
static __global__ void add_id_kernel(
        const float * src0, const float * src1, const int32_t * src2, float * dst,
        int64_t ne0, int64_t ne1,
        size_t nb01, size_t nb02,
        size_t nb11,
        size_t nb21
    ) {
    const int64_t row   = blockIdx.x;
    const int64_t plane = blockIdx.y;

    // Packed byte strides of dst (derived, not supplied by the caller).
    const size_t dst_row_bytes   = ne0 * sizeof(float);
    const size_t dst_plane_bytes = ne1 * dst_row_bytes;

    // Which src1 row to add to this (row, plane) pair.
    const char * ids_base = (const char *) src2;
    const int32_t id = *(const int32_t *) (ids_base + plane*nb21 + row*sizeof(int32_t));

    const float * lhs = (const float *) ((const char *) src0 + plane*nb02 + row*nb01);
    const float * rhs = (const float *) ((const char *) src1 + id*nb11);
    float       * out = (float *) ((char *) dst + plane*dst_plane_bytes + row*dst_row_bytes);

    // Block-stride walk over the columns of this row.
    const int64_t step = blockDim.x;
    int64_t col = threadIdx.x;
    while (col < ne0) {
        out[col] = lhs[col] + rhs[col];
        col += step;
    }
}
| 27 | |
// Host launcher for the add_id op: dst[:, i1, i2] = src0[:, i1, i2] + src1[:, src2[i1, i2]].
//
// Launch layout: one block per (i1, i2) pair of src0 (grid = ne01 x ne02), with at most 768
// threads per block striding over the ne00 columns. Enqueued on ctx.stream(); no sync here.
//
// Preconditions (all asserted below):
//   - dst/src0/src1 are F32 and src2 is I32, with packed elements inside each row.
//   - dst has src0's shape and is contiguous: the kernel derives dst strides from ne0/ne1.
//   - src1 rows are exactly as wide as src0 rows (the kernel reads ne0 floats per src1 row).
//   - ne02 fits in gridDim.y (<= 65535).
// Empty tensors are a no-op (a zero-sized launch would be an invalid configuration).
// NOTE(review): only dims 0..2 are iterated; a dst with ne3 > 1 would have its extra planes
// left unwritten — confirm callers only pass 3-D tensors.
// NOTE(review): id values in src2 are not range-checked against src1's row count.
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    // Elements inside each row must be packed; the kernel indexes rows as plain arrays.
    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));
    GGML_ASSERT(nb20 == sizeof(int32_t));

    // The kernel writes dst at packed offsets computed from ne0/ne1 and walks the grid over
    // src0's shape, so dst must match src0's shape and be contiguous.
    GGML_ASSERT(ne0 == ne00 && ne1 == ne01 && ne2 == ne02);
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb1 == ne0*nb0);
    GGML_ASSERT(nb2 == ne1*nb1);

    // Each src1 row selected by an id is read for ne0 columns.
    GGML_ASSERT(ne10 == ne00);

    // ne02 becomes gridDim.y, which CUDA caps at 65535.
    GGML_ASSERT(ne02 <= 65535);

    // Nothing to compute; also avoids launching with 0 threads or a 0-sized grid.
    if (ne00 == 0 || ne01 == 0 || ne02 == 0) {
        return;
    }

    const float   * src0_d = (const float *)   src0->data;
    const float   * src1_d = (const float *)   src1->data;
    const int32_t * src2_d = (const int32_t *) src2->data;
    float         * dst_d  = (float *)         dst->data;

    // Clamp in 64-bit before narrowing so a huge ne00 cannot wrap the thread count.
    const int threads = (int) std::min<int64_t>(ne00, 768); // cols
    const dim3 blocks((unsigned int) ne01, (unsigned int) ne02); // n_experts_used, n_tokens
    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
        src0_d, src1_d, src2_d, dst_d,
        ne0, ne1,
        nb01, nb02,
        nb11,
        nb21
    );
}
| 59 | |