#include "argsort.cuh"

#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB

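// fill each row of the index buffer with 0..ncols-1; these are the values that
// the segmented key sort below reorders into the final argsort result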
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y;

    if (col < ncols && row < nrows) {
        indices[row * ncols + col] = col;
    }
}

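// write the nrows + 1 segment offsets [0, ncols, 2*ncols, ..., nrows*ncols] that
// CUB's segmented sort uses as begin/end markers for each row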
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx <= nrows) {
        offsets[idx] = idx * ncols;
    }
}

#ifdef GGML_CUDA_USE_CUB
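// CUB path: sort all rows in one call with DeviceSegmentedRadixSort. CUB sorts the
// keys together with the values, so the input is first copied into a mutable scratch
// key buffer and the per-row indices are used as the values that end up in dst.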
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                     const float * x,
                                     int * dst,
                                     const int ncols,
                                     const int nrows,
                                     ggml_sort_order order,
                                     cudaStream_t stream) {
    ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);

    int * temp_indices = temp_indices_alloc.get();
    float * temp_keys = temp_keys_alloc.get();
    int * d_offsets = offsets_alloc.get();

    static const int block_size = 256;
    const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);

    // nrows + 1 offsets are written, so round the grid up accordingly
    const dim3 offset_grid((nrows + 1 + block_size - 1) / block_size);
    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);

    CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));

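    // CUB is called twice: the first call (null temp storage) only reports how many
    // scratch bytes are needed, the second call performs the actual sort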
    size_t temp_storage_bytes = 0;

    if (order == GGML_SORT_ORDER_ASC) {
        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                            temp_indices, dst,                                  // values (indices)
                                            ncols * nrows, nrows,                               // num items, num segments
                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,     // all bits
                                            stream);
    } else {
        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
                                                      sizeof(float) * 8, stream);
    }

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void * d_temp_storage = temp_storage_alloc.get();

    if (order == GGML_SORT_ORDER_ASC) {
        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
                                            stream);
    } else {
        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
                                                      0, sizeof(float) * 8, stream);
    }
}
#endif // GGML_CUDA_USE_CUB

// Bitonic sort implementation
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T tmp = a;
    a = b;
    b = tmp;
}

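// one thread block sorts one row: the indices live in shared memory and the row is
// padded to a power of two; padded slots (index >= ncols) always compare as coming
// last, so they collect at the end and are dropped when the result is written out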
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
    // bitonic sort
    int col = threadIdx.x;
    int row = blockIdx.x;

    if (col >= ncols_pad) {
        return;
    }

    const float * x_row = x + row * ncols;
    extern __shared__ int dst_row[];

    // initialize indices
    dst_row[col] = col;

    __syncthreads();

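    // standard bitonic sort: k is the length of the bitonic sequences being merged,
    // j is the compare-exchange stride, and each thread handles the pair (col, col ^ j)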
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= ncols ||
                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= ncols ||
                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            __syncthreads();
        }
    }

    // copy the result to dst without the padding
    if (col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}

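// smallest power of two >= x, used to pad a row for the bitonic network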
static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

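// bitonic path: one block per row with ncols_pad threads and the index buffer in shared memory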
static void argsort_f32_i32_cuda_bitonic(const float * x,
                                         int * dst,
                                         const int ncols,
                                         const int nrows,
                                         ggml_sort_order order,
                                         cudaStream_t stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);

    const dim3 block_dims(ncols_pad, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
    const size_t shared_mem = ncols_pad * sizeof(int);

    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else if (order == GGML_SORT_ORDER_DESC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ABORT("fatal error");
    }
}

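// GGML_OP_ARGSORT entry point: sorts each row of src0 and writes the resulting
// permutation of column indices into dst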
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

#ifdef GGML_CUDA_USE_CUB
    const int ncols_pad = next_power_of_2(ncols);
    const size_t shared_mem = ncols_pad * sizeof(int);
    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

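    // the bitonic kernel handles a whole row in a single block, so fall back to CUB's
    // segmented sort when the padded index buffer does not fit in shared memory or
    // when ncols exceeds the 1024-thread block limit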
    if (shared_mem > max_shared_mem || ncols > 1024) {
        ggml_cuda_pool & pool = ctx.pool();
        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
    } else {
        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
    }
#else
    argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
}