#include <algorithm>
#include <cfloat> // FLT_MAX (assumed not already pulled in via common.cuh)
#include <cstdint>

#include "argmax.cuh"
#include "common.cuh"
#include "sum.cuh"

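// Row-wise argmax kernel: each block reduces one row of x to the index of its
// maximum element. Every thread scans a strided slice of the row, then the
// per-thread candidates are combined with warp shuffles and, for blocks
// spanning more than one warp, through shared memory.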
static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
    const int64_t row = blockIdx.x;

    float maxval = -FLT_MAX;
    int   argmax = -1;
    const float * rowx = x + row * ncols;

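    // Strided scan: each thread tracks the running maximum and its column index
    // over the elements it visits.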
    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float val = rowx[col];
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

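    // Butterfly reduction over the warp: after log2(WARP_SIZE) shuffle steps
    // every lane holds the warp-wide maximum and the corresponding column.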
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

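    // Reduce across warps: lane 0 of each warp stages its result in shared
    // memory and the first warp performs the final reduction.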
    const int n_warps = blockDim.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    if (n_warps > 1) {
        constexpr int    max_warps = 1024 / WARP_SIZE;
        __shared__ float shared_maxval[max_warps];
        __shared__ int   shared_argmax[max_warps];
        if (lane_id == 0) {
            shared_maxval[warp_id] = maxval;
            shared_argmax[warp_id] = argmax;
        }

        __syncthreads();

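        // The first warp loads one candidate per warp and repeats the shuffle
        // reduction to obtain the block-wide result.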
        if (warp_id == 0) {
            if (lane_id < n_warps) {
                maxval = shared_maxval[lane_id];
                argmax = shared_argmax[lane_id];
            }
#pragma unroll
            for (int offset = 16; offset > 0; offset >>= 1) {
                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                if (val > maxval) {
                    maxval = val;
                    argmax = col;
                }
            }
        }
    }

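    // Lane 0 of warp 0 now holds the argmax of the entire row.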
    if (warp_id == 0 && lane_id == 0) {
        dst[row] = argmax;
    }
}

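// Host-side launcher: validates types and layout, then launches one block per
// row of src0 on the context's stream.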
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);

    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ne00  = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    const float * src0_d = (const float *) src0->data;
    int32_t     * dst_d  = (int32_t     *) dst->data;

    cudaStream_t stream = ctx.stream();

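    // One block per row; the thread count is ne00 rounded up to a multiple of
    // WARP_SIZE and capped at the 1024-thread block limit.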
    const int64_t num_blocks  = nrows;
    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
    const dim3 blocks_dim(num_threads, 1, 1);
    const dim3 blocks_num(num_blocks, 1, 1);

    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
}