#include "argsort.cuh"

#ifdef GGML_CUDA_USE_CUB
#    include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB

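// Fill each row of the temporary index buffer with 0..ncols-1; the segmented sort
// below reorders these indices by key value.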
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y;

    if (col < ncols && row < nrows) {
        indices[row * ncols + col] = col;
    }
}

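// Write the nrows + 1 segment offsets (0, ncols, 2*ncols, ...) that delimit each row
// for CUB's segmented sort.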
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx <= nrows) {
        offsets[idx] = idx * ncols;
    }
}

#ifdef GGML_CUDA_USE_CUB
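// Argsort via cub::DeviceSegmentedRadixSort: each row is one segment, the keys are a
// scratch copy of the input values and the sorted "values" are the column indices.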
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                     const float *    x,
                                     int *            dst,
                                     const int        ncols,
                                     const int        nrows,
                                     ggml_sort_order  order,
                                     cudaStream_t     stream) {
    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);

    int *   temp_indices = temp_indices_alloc.get();
    float * temp_keys    = temp_keys_alloc.get();
    int *   d_offsets    = offsets_alloc.get();

    static const int block_size = 256;
    const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);

    // nrows + 1 offsets are written, so round the grid size up accordingly
    const dim3 offset_grid((nrows + 1 + block_size - 1) / block_size);
    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);

    CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));

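    // CUB convention: a first call with a null temp-storage pointer only computes
    // temp_storage_bytes; the sort itself happens in the second call below.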
    size_t temp_storage_bytes = 0;

    if (order == GGML_SORT_ORDER_ASC) {
        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                            temp_indices, dst,                                  // values (indices)
                                            ncols * nrows, nrows,                               // num items, num segments
                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,     // all bits
                                            stream);
    } else {
        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
                                                      sizeof(float) * 8, stream);
    }

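    // allocate the scratch space reported by the size query, then run the actual sort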
    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void * d_temp_storage = temp_storage_alloc.get();

    if (order == GGML_SORT_ORDER_ASC) {
        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
                                            stream);
    } else {
        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
                                                      0, sizeof(float) * 8, stream);
    }
}
#endif // GGML_CUDA_USE_CUB

// Bitonic sort implementation
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T tmp = a;
    a = b;
    b = tmp;
}

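// One thread block per row: the column indices live in shared memory and are sorted with a
// bitonic network over ncols_pad (power of two) slots; padding slots always compare as if
// they sorted after every real element, so they collect at the tail and are never written out.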
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
    // bitonic sort
    int col = threadIdx.x;
    int row = blockIdx.x;

    if (col >= ncols_pad) {
        return;
    }

    const float * x_row = x + row * ncols;
    extern __shared__ int dst_row[];

    // initialize indices
    dst_row[col] = col;

    __syncthreads();

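    // standard bitonic sorting network: k is the length of the bitonic sequences being
    // merged, j the compare-exchange distance; each thread handles one slot per step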
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= ncols ||
                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= ncols ||
                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            __syncthreads();
        }
    }

    // copy the result to dst without the padding
    if (col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}

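// smallest power of two that is >= x (returns 1 for x <= 1)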
static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

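// Launch the bitonic kernel: one block per row with ncols_pad threads and one int of
// shared memory per padded column, which bounds the row width this path can handle.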
static void argsort_f32_i32_cuda_bitonic(const float *   x,
                                         int *           dst,
                                         const int       ncols,
                                         const int       nrows,
                                         ggml_sort_order order,
                                         cudaStream_t    stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);

    const dim3 block_dims(ncols_pad, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
    const size_t shared_mem = ncols_pad * sizeof(int);

    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else if (order == GGML_SORT_ORDER_DESC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ABORT("fatal error");
    }
}

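// Dispatch: rows that fit the bitonic kernel's shared-memory and width limits use the
// in-block sort; wider rows fall back to the CUB segmented radix sort when it is compiled in.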
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

#ifdef GGML_CUDA_USE_CUB
    const int ncols_pad = next_power_of_2(ncols);
    const size_t shared_mem = ncols_pad * sizeof(int);
    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

    if (shared_mem > max_shared_mem || ncols > 1024) {
        ggml_cuda_pool & pool = ctx.pool();
        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
    } else {
        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
    }
#else
    argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
}