#include "common.cuh"

// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
    // Accumulate into num_unroll independent partial sums to expose instruction-level
    // parallelism; each thread strides through the row in steps of blockDim.x.
    const int num_unroll = 8;
    float temp[num_unroll];
    float sum_temp[num_unroll] = { 0.0f };
    for (int i = col; i < ncols;) {
        for (int j = 0; j < num_unroll; ++j) {
            if (i < ncols) {
                temp[j] = x[row * ncols + i];
            } else {
                temp[j] = 0.0f;
            }
            i += blockDim.x;
        }
        for (int j = 0; j < num_unroll; ++j) {
            sum_temp[j] += temp[j];
        }
    }
    for (int j = 0; j < num_unroll; ++j) {
        sum += sum_temp[j];
    }

    // sum up partial sums: first within each warp, then across warps via shared memory
    sum = warp_reduce_sum(sum);
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = 0.0f;
        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
            sum = s_sum[lane_id];
        }
        sum = warp_reduce_sum(sum);
    }

    // only thread 0 of each block writes the result for its row
    if (col != 0) {
        return;
    }

    dst[row] = norm ? sum / ncols : sum;
}
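
// Example host-side launch (a sketch, not part of the original file): one thread block
// per row, with a block size that satisfies the kernel's assert (a multiple of WARP_SIZE,
// at most 1024 threads). The helper name reduce_rows_f32_cuda and the fixed block size of
// 256 threads are illustrative assumptions, not the library's actual launch configuration.
static void reduce_rows_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows,
        const bool mean, cudaStream_t stream) {
    const dim3 block_dims(256, 1, 1);   // threads per block: multiple of WARP_SIZE, <= 1024
    const dim3 block_nums(nrows, 1, 1); // one block per row (blockIdx.x selects the row)
    if (mean) {
        reduce_rows_f32<true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
    } else {
        reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
    }
}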