#include "set.cuh"
#include "cpy.cuh"

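// GGML_OP_SET: write the elements of src1 into a view of dst whose byte strides
// and byte offset are taken from dst->op_params. Unless the op is marked
// in-place, dst is first initialized with a full copy of src0.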
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
    GGML_ASSERT(src1->type == src0->type);
    GGML_ASSERT(dst ->type == src0->type);

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

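    // op_params layout (as set up by ggml_set): view byte strides nb1/nb2/nb3,
    // byte offset of the view inside dst, and an in-place flag.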
    const size_t nb1     = ((int32_t *) dst->op_params)[0];
    const size_t nb2     = ((int32_t *) dst->op_params)[1];
    const size_t nb3     = ((int32_t *) dst->op_params)[2];
    const size_t offset  = ((int32_t *) dst->op_params)[3];
    const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];

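    // Out-of-place: materialize src0 into dst before overwriting the view region.
    // In-place: dst already aliases src0's buffer, so only the view is written.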
    if (!inplace) {
        ggml_cuda_cpy(ctx, src0, dst);
    }

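    // Describe the destination view: dst's buffer shifted by `offset` bytes,
    // with src1's shape.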
    ggml_tensor dst_view = *dst;
    dst_view.data = (void *)((char *)dst->data + offset);
    dst_view.ne[0] = src1->ne[0];
    dst_view.ne[1] = src1->ne[1];
    dst_view.ne[2] = src1->ne[2];
    dst_view.ne[3] = src1->ne[3];

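    // The innermost stride is one element; the outer strides come from op_params,
    // so the view is embedded in dst's layout rather than packed.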
    dst_view.nb[0] = ggml_element_size(dst);
    dst_view.nb[1] = nb1;
    dst_view.nb[2] = nb2;
    dst_view.nb[3] = nb3;

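    // Copy src1 into the (possibly non-contiguous) view of dst.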
    ggml_cuda_cpy(ctx, src1, &dst_view);
}