1#include "conv2d-dw.cuh"
2
// Geometry of one depthwise convolution, bundled so device helpers can take it by const reference.
// Keep the field order stable: conv2d_dw_kernel fills this struct by aggregate initialization in
// declaration order.
struct conv_params {
    int in_w;        // input width (elements per input row)
    int in_h;        // input height (rows per input plane)
    int out_w;       // output width
    int out_h;       // output height
    int kernel_w;    // number of kernel taps along x
    int kernel_h;    // number of kernel taps along y
    int stride_x;    // step in input x per output x
    int stride_y;    // step in input y per output y
    int padding_x;   // offset subtracted from input x coordinates (implicit padding)
    int padding_y;   // offset subtracted from input y coordinates (implicit padding)
    int dilation_x;  // spacing in input x between consecutive kernel taps
    int dilation_y;  // spacing in input y between consecutive kernel taps
    int channels;    // number of channels; each is convolved with its own kernel slice
    int batches;     // number of images in the batch
};
12
// Half-open ranges of kernel tap indices for one output pixel: rows [y_min, y_max) and
// columns [x_min, x_max). An empty range (min >= max) means no tap contributes on that axis.
struct kernel_bounds {
    int y_min;  // first kernel row to visit
    int y_max;  // one past the last kernel row to visit
    int x_min;  // first kernel column to visit
    int x_max;  // one past the last kernel column to visit
};
17
// For the output pixel (out_x, out_y), computes which kernel taps read an input coordinate inside the
// image, i.e. taps k with 0 <= out * stride + k * dilation - padding < in_extent on each axis.
// The lower bound is ceil((padding - out * stride) / dilation) clamped to 0; the upper (exclusive)
// bound is ceil((in_extent + padding - out * stride) / dilation) clamped to the kernel size.
// Skipping the out-of-range taps is what implements zero padding without per-tap bounds checks.
// Assumes dilation_x / dilation_y are non-zero (they are used as divisors).
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    // Input coordinate read by kernel tap 0 for this output pixel (may be negative inside the padding).
    const int origin_y = out_y * params.stride_y - params.padding_y;
    const int origin_x = out_x * params.stride_x - params.padding_x;

    // Adding (dilation - 1) before dividing rounds up for non-negative numerators; negative
    // numerators only occur when the result is clamped away by max/min anyway.
    const int round_y = params.dilation_y - 1;
    const int round_x = params.dilation_x - 1;

    kernel_bounds result;
    result.y_min = max(0, (round_y - origin_y) / params.dilation_y);
    result.y_max = min(params.kernel_h, (params.in_h - origin_y + round_y) / params.dilation_y);
    result.x_min = max(0, (round_x - origin_x) / params.dilation_x);
    result.x_max = min(params.kernel_w, (params.in_w - origin_x + round_x) / params.dilation_x);
    return result;
}
30
// Input coordinate (along one axis) read by kernel tap `kern_coord` for output coordinate `out_coord`:
// out_coord * stride + kern_coord * dilation - padding. No clamping is done here; callers are expected
// to restrict kern_coord (e.g. via calculate_kernel_bounds) so the result stays inside the input.
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    const int window_origin = out_coord * stride - padding;
    return window_origin + kern_coord * dilation;
}
34
// Flat-offset mapping for tensors stored with x (width) fastest, then y, then channel, then batch.
// The depthwise kernel is addressed as [channel][ky][kx] with kx fastest.
struct whcn_layout {
    // Offset of input element (n, c, y, x).
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        const int plane = n * params.channels + c;  // index of the (batch, channel) image plane
        return (plane * params.in_h + y) * params.in_w + x;
    }

    // Offset of kernel tap (ky, kx) belonging to channel c.
    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        return (c * params.kernel_h + ky) * params.kernel_w + kx;
    }

    // Offset of output element (n, c, y, x).
    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        const int plane = n * params.channels + c;
        return (plane * params.out_h + y) * params.out_w + x;
    }

    // Inverse of output_index: splits a flat output offset into (n, c, out_y, out_x) by peeling off
    // the fastest-varying dimension first.
    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        int rest = global_idx;
        out_x = rest % params.out_w;
        rest /= params.out_w;
        out_y = rest % params.out_h;
        rest /= params.out_h;
        c = rest % params.channels;
        n = rest / params.channels;
    }
};
57
// Flat-offset mapping for tensors stored with channel fastest, then x (width), then y, then batch.
// The depthwise kernel is addressed as [ky][kx][channel] with the channel fastest.
struct cwhn_layout {
    // Offset of input element (n, c, y, x).
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        const int pixel = (n * params.in_h + y) * params.in_w + x;  // linear (batch, row, column) position
        return pixel * params.channels + c;
    }

    // Offset of kernel tap (ky, kx) belonging to channel c.
    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        const int tap = ky * params.kernel_w + kx;
        return tap * params.channels + c;
    }

    // Offset of output element (n, c, y, x).
    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        const int pixel = (n * params.out_h + y) * params.out_w + x;
        return pixel * params.channels + c;
    }

    // Inverse of output_index: splits a flat output offset into (n, c, out_y, out_x), channel first.
    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        int rest = global_idx;
        c = rest % params.channels;
        rest /= params.channels;
        out_x = rest % params.out_w;
        rest /= params.out_w;
        out_y = rest % params.out_h;
        n = rest / params.out_h;
    }
};
80
// Direct depthwise 2D convolution: each thread computes exactly one output element, and output
// channel c reads only input channel c and kernel channel c.
//
// Launch: 1D grid of 1D blocks providing at least batches * channels * out_h * out_w threads;
// surplus threads at the grid tail return immediately. No shared memory is used.
// Layout (whcn_layout or cwhn_layout) defines how (n, c, y, x) map to flat offsets in the input,
// kernel and output buffers.
// Precondition: every tensor has at most INT_MAX elements, because all per-element offsets are
// computed in 32-bit int.
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
                                 const int in_w, const int in_h, const int out_w, const int out_h,
                                 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
                                 const int channels, const int batches) {
    // Use 64-bit arithmetic for the thread index and the element count. blockIdx.x * blockDim.x is
    // evaluated in 32-bit unsigned arithmetic. Narrowing it to int, or multiplying the four extents in
    // int, can wrap negative near INT_MAX, which would let a tail thread slip past the check below.
    const int64_t global_idx     = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t total_elements = (int64_t) batches * channels * out_h * out_w;

    if (global_idx >= total_elements) {
        return;
    }

    // Positional aggregate initialization: order must match the conv_params field declaration order.
    const conv_params params = { in_w,     in_h,      out_w,     out_h,      kernel_w,   kernel_h, stride_x,
                                 stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };

    int batch_idx, channel_idx, out_y_idx, out_x_idx;
    // Exact narrowing: 0 <= global_idx < total_elements, which fits in int per the precondition above.
    Layout::unpack_indices((int) global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);

    T accumulator = 0;
    // Only taps whose input coordinate is inside the image are visited; skipped taps act as zero padding.
    const kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);

    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
        const int in_y_idx =
            calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);

        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
            const int in_x_idx =
                calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);

            const T input_val  = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];

            accumulator += input_val * kernel_val;
        }
    }

    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}
118
// CUDA implementation of the depthwise 2D convolution op.
//
// Inputs:
//   - dst->src[0]: kernel; dst->src[1]: input image. The kernel, the input and dst must all be F32.
//   - dst->op_params: int32 {stride_x, stride_y, padding_x, padding_y, dilation_x, dilation_y}.
//   - Shapes are read from ne[]: input (W, H, ...), kernel (KW, KH, ...), dst (OW, OH, C, N).
//
// Layout selection:
//   - A fully contiguous input is processed with whcn_layout.
//   - A channels-innermost input (ggml_is_contiguous_channels) is processed with cwhn_layout.
//   - Any other input layout aborts.
// NOTE(review): only the input's layout is inspected. The kernel and dst buffers are assumed to use
// the matching layout, since no strides are passed to the device — confirm with the graph builder.
//
// The kernel is launched asynchronously on the context's stream.
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];

    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

    // The device code computes every offset with 32-bit int arithmetic, and the shapes below are
    // narrowed from int64 to int. Reject tensors that would overflow instead of silently
    // reading or writing out of bounds.
    GGML_ASSERT(ggml_nelements(input) <= INT32_MAX);
    GGML_ASSERT(ggml_nelements(kernel) <= INT32_MAX);
    GGML_ASSERT(ggml_nelements(dst) <= INT32_MAX);

    const int64_t total = ggml_nelements(dst);
    if (total == 0) {
        // Nothing to compute; a launch with zero blocks is an invalid configuration.
        return;
    }

    const float * w_d = (const float *) kernel->data;
    const float * x_d = (const float *) input->data;
    float *       y_d = (float *) dst->data;

    const int32_t * p          = (const int32_t *) dst->op_params;
    const int       stride_x   = p[0];
    const int       stride_y   = p[1];
    const int       padding_x  = p[2];
    const int       padding_y  = p[3];
    const int       dilation_x = p[4];
    const int       dilation_y = p[5];

    const int in_w     = input->ne[0];
    const int in_h     = input->ne[1];
    const int kernel_w = kernel->ne[0];
    const int kernel_h = kernel->ne[1];
    const int out_w    = dst->ne[0];
    const int out_h    = dst->ne[1];
    const int channels = dst->ne[2];
    const int batches  = dst->ne[3];

    cudaStream_t st = ctx.stream();

    // One thread per output element, rounded up to whole blocks.
    const int blocks = (int) ((total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE);

    if (ggml_is_contiguous(input)) {
        conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else if (ggml_is_contiguous_channels(input)) {
        conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else {
        GGML_ABORT("Unsupported memory layout for conv_2d_dw");
    }
}
162