1#include "set.cuh"
2#include "cpy.cuh"
3
// GGML_OP_SET on CUDA: dst becomes a copy of src0 with src1 written into a strided
// window of dst.
//
// op_params layout (read as int32_t):
//   [0] nb1, [1] nb2, [2] nb3 - byte strides of the dst window for dims 1..3
//   [3] offset                - byte offset of the window's first element inside dst
//   [4] inplace               - when non-zero, src0 is NOT copied into dst first
//                               (presumably dst already aliases src0's buffer)
//
// Preconditions (asserted): src0, src1 and dst share one type (F32 or I32) and are
// all contiguous; the window described by op_params lies entirely inside dst.
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
    GGML_ASSERT(src1->type == src0->type);
    GGML_ASSERT(dst ->type == src0->type);

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int32_t * params = (const int32_t *) dst->op_params;

    // Strides and offset are stored as int32_t. A negative value means the byte
    // quantity did not fit when the op was built; widening it to size_t would
    // sign-extend to a huge offset and write far outside dst, so reject it here.
    GGML_ASSERT(params[0] >= 0 && params[1] >= 0 && params[2] >= 0 && params[3] >= 0);

    const size_t nb1     = (size_t) params[0];
    const size_t nb2     = (size_t) params[1];
    const size_t nb3     = (size_t) params[2];
    const size_t offset  = (size_t) params[3];
    const bool   inplace = params[4] != 0;

    if (!inplace) {
        // Start from a full copy of src0; the window is overwritten below.
        ggml_cuda_cpy(ctx, src0, dst);
    }

    // An empty src1 leaves the window untouched; avoid issuing a zero-element copy.
    if (ggml_nelements(src1) == 0) {
        return;
    }

    const size_t nb0 = ggml_element_size(dst);

    // One past the last byte written through the window must not exceed dst's size,
    // otherwise the copy below would write out of bounds on the device.
    // Every src1->ne[i] >= 1 here because src1 is non-empty.
    const size_t window_end = offset
        + (size_t) (src1->ne[0] - 1) * nb0
        + (size_t) (src1->ne[1] - 1) * nb1
        + (size_t) (src1->ne[2] - 1) * nb2
        + (size_t) (src1->ne[3] - 1) * nb3
        + nb0;
    GGML_ASSERT(window_end <= ggml_nbytes(dst));

    // Describe the destination window as a strided view of dst with src1's shape.
    ggml_tensor dst_view = *dst;
    dst_view.data  = (void *) ((char *) dst->data + offset);
    dst_view.ne[0] = src1->ne[0];
    dst_view.ne[1] = src1->ne[1];
    dst_view.ne[2] = src1->ne[2];
    dst_view.ne[3] = src1->ne[3];

    dst_view.nb[0] = nb0;
    dst_view.nb[1] = nb1;
    dst_view.nb[2] = nb2;
    dst_view.nb[3] = nb3;

    ggml_cuda_cpy(ctx, src1, &dst_view);
}
40