#include "pad_reflect_1d.cuh"

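// One thread computes one f32 element of the padded output row: blockIdx.x
// jointly encodes the row index i1 and the tile index along i0 (unpacked via
// fastdiv below), while blockIdx.y and blockIdx.z map directly to i2 and i3.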
static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
    pad_reflect_1d_kernel_f32(
        const void * __restrict__ src0,
        void * __restrict__       dst,
        const int64_t             ne0,
        const int64_t             ne00,
        const uint3               ne01,
        const int64_t             ne02,
        const int64_t             ne03,
        const int64_t             nb00,
        const int64_t             nb01,
        const int64_t             nb02,
        const int64_t             nb03,
        const int64_t             nb0,
        const int64_t             nb1,
        const int64_t             nb2,
        const int64_t             nb3,
        const int                 p0,
        const int                 p1) {
    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;

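    // grid.x was launched as ne01 * tiles0, so blockIdx.x packs (tile0, i1).
    // fast_div_modulo() recovers quotient and remainder with a multiply-and-shift
    // instead of a hardware division, using the magic constants that
    // init_fastdiv_values() packed into ne01 (see common.cuh).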
    const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
    const int64_t tile1 = div_mod_packed.y; // i1
    const int64_t tile0 = div_mod_packed.x; // nth i0 tile
    const int64_t i1 = tile1;
    const int64_t i0 = threadIdx.x + tile0 * blockDim.x;

    // ne01.z holds the original (pre-fastdiv-packing) value of ne01 (see init_fastdiv_values in common.cuh)
    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
    char *       dst_ptr  = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
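    // The i1/i2/i3 byte offsets are the same for source and destination;
    // only the innermost index i0 is remapped by the reflection below.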

    const int64_t rel_i0 = i0 - p0; // relative i0 in src0
    int64_t       src_idx;
    if (rel_i0 < 0) {
        // Left padding - reflect
        src_idx = -rel_i0;
    } else if (rel_i0 < ne00) {
        // Middle - copy
        src_idx = rel_i0;
    } else {
        // Right padding - reflect
        src_idx = 2 * ne00 - 2 - rel_i0;
    }
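    // Worked example (hypothetical sizes): ne00 = 4, p0 = 2, p1 = 1 gives
    // ne0 = 7 and src_idx = [2, 1, 0, 1, 2, 3, 2] for i0 = 0..6, i.e. the
    // edge elements are mirrored about but not repeated.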
    const float value = *(const float *) (src0_ptr + src_idx * nb00);
    *(float *) (dst_ptr + i0 * nb0) = value;

    GGML_UNUSED(p1);
}
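
// Host-side launcher for GGML_OP_PAD_REFLECT_1D. In graph code this op is
// typically built with ggml_pad_reflect_1d(ctx, a, p0, p1) (per ggml.h); the
// two pad amounts arrive here through dst->op_params.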
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int       p0   = opts[0];
    const int       p1   = opts[1];

    const int64_t ne00        = src0->ne[0];
    const int64_t ne01        = src0->ne[1];
    const uint3   ne01_packed = init_fastdiv_values(ne01);
    const int64_t ne02        = src0->ne[2];
    const int64_t ne03        = src0->ne[3];

    const int64_t ne0 = dst->ne[0];

    // sanity check: the destination length must equal the padded source length
    GGML_ASSERT(ne0 == ne00 + p0 + p1);
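    // Reflection stays in-bounds only for p0 <= ne00 - 1 and p1 <= ne00 - 1
    // (then src_idx always lies within [0, ne00)); this is assumed to be
    // enforced when the op is constructed, so it is not re-checked here.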

    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
    const int64_t     tiles0 = (ne0 + bx - 1) / bx;            // number of tiles along i0
    // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
    // grid.y covers i2: [ne02]
    // grid.z covers i3: [ne03]
    const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
    const dim3 block_dims((unsigned) bx, 1, 1);

    pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03,
        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
}