#include "conv2d-dw.cuh"

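// Convolution hyperparameters and tensor extents, grouped so the device helpers
// can take them by const reference.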
struct conv_params {
    int in_w, in_h;
    int out_w, out_h;
    int kernel_w, kernel_h;
    int stride_x, stride_y;
    int padding_x, padding_y;
    int dilation_x, dilation_y;
    int channels, batches;
};

struct kernel_bounds {
    int y_min, y_max;
    int x_min, x_max;
};

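// Clamp the kernel window to the valid input region for one output pixel.
// For each axis, the smallest tap index k with out * stride + k * dilation - padding >= 0
// is found by a ceiling division, and the exclusive upper bound is derived the
// same way from the input extent, so the inner loops never read out of bounds.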
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    kernel_bounds bounds;
    bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.y_max =
        min(params.kernel_h,
            (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    bounds.x_max =
        min(params.kernel_w,
            (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    return bounds;
}

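// Map an output coordinate and kernel tap back to the corresponding input coordinate.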
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}

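// Index helpers for WHCN tensors (width fastest, then height, channel, batch),
// i.e. ggml's default contiguous layout.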
struct whcn_layout {
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
    }

    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
    }

    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
               y * params.out_w + x;
    }

    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        out_x = global_idx % params.out_w;
        out_y = (global_idx / params.out_w) % params.out_h;
        c = (global_idx / (params.out_w * params.out_h)) % params.channels;
        n = global_idx / (params.out_w * params.out_h * params.channels);
    }
};

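// Index helpers for CWHN tensors (channel fastest), used when the input is
// contiguous along the channel dimension.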
struct cwhn_layout {
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
    }

    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        return (ky * params.kernel_w + kx) * params.channels + c;
    }

    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
               x * params.channels + c;
    }

    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        c = global_idx % params.channels;
        out_x = (global_idx / params.channels) % params.out_w;
        out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
        n = global_idx / (params.channels * params.out_w * params.out_h);
    }
};

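// Depthwise 2D convolution: each thread computes one output element, iterating
// only over the kernel taps that fall inside the input for that position.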
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
                                 const int in_w, const int in_h, const int out_w, const int out_h,
                                 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
                                 const int channels, const int batches) {
    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int total_elements = batches * channels * out_h * out_w;

    if (global_idx >= total_elements) {
        return;
    }

    conv_params params = { in_w,     in_h,      out_w,     out_h,      kernel_w,   kernel_h, stride_x,
                           stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };

    int batch_idx, channel_idx, out_y_idx, out_x_idx;
    Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);

    T accumulator = 0;
    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);

    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);

        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
            int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);

            const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];

            accumulator += input_val * kernel_val;
        }
    }

    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}

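// Host-side launcher. Hyperparameters come from dst->op_params in the order
// [stride_x, stride_y, padding_x, padding_y, dilation_x, dilation_y]; the output
// extent follows the usual relation out = (in + 2*pad - dilation*(kernel-1) - 1) / stride + 1.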
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input = dst->src[1];

    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    const float * w_d = (const float *) kernel->data;
    const float * x_d = (const float *) input->data;
    float * y_d = (float *) dst->data;

    const int32_t * p = (const int32_t *) dst->op_params;
    const int stride_x = p[0];
    const int stride_y = p[1];
    const int padding_x = p[2];
    const int padding_y = p[3];
    const int dilation_x = p[4];
    const int dilation_y = p[5];

    const int in_w = input->ne[0];
    const int in_h = input->ne[1];
    const int kernel_w = kernel->ne[0];
    const int kernel_h = kernel->ne[1];
    const int out_w = dst->ne[0];
    const int out_h = dst->ne[1];
    const int channels = dst->ne[2];
    const int batches = dst->ne[3];

    cudaStream_t st = ctx.stream();

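    // One thread per output element, rounded up to whole blocks.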
    const int total = batches * channels * out_h * out_w;
    const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;

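    // Fully contiguous tensors are assumed to be in WHCN order; tensors that are
    // contiguous along the channel dimension are assumed to be in CWHN order.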
    if (ggml_is_contiguous(input)) {
        conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else if (ggml_is_contiguous_channels(input)) {
        conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else {
        GGML_ABORT("Unsupported memory layout for conv_2d_dw");
    }
}