#include "common.cuh"

// Row reduction kernel template - computes the sum (norm == false) or mean (norm == true) of each row of x
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
    const int num_unroll = 8;
    float temp[num_unroll];
    float sum_temp[num_unroll] = { 0.0f };
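    // Each thread strides over the row in steps of blockDim.x. Buffering num_unroll
    // loads in registers before accumulating keeps several loads in flight at once
    // and avoids a serial dependency on a single accumulator.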
    for (int i = col; i < ncols;) {
        for (int j = 0; j < num_unroll; ++j) {
            if (i < ncols) {
                temp[j] = x[row * ncols + i];
            } else {
                temp[j] = 0.0f; // zero-pad past the end of the row
            }
            i += blockDim.x;
        }
        for (int j = 0; j < num_unroll; ++j) {
            sum_temp[j] += temp[j];
        }
    }
    // collapse the independent partial sums into a single per-thread sum
    for (int j = 0; j < num_unroll; ++j) {
        sum += sum_temp[j];
    }

    // sum up the per-thread partial sums within each warp
    sum = warp_reduce_sum(sum);
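    // For blocks wider than one warp, the per-warp results still have to be
    // combined: lane 0 of each warp publishes its sum to shared memory and
    // warp 0 reduces those values once more.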
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        __shared__ float s_sum[32]; // one slot per warp; at most 1024 / WARP_SIZE = 32 warps
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = 0.0f;
        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
            sum = s_sum[lane_id];
        }
        sum = warp_reduce_sum(sum);
    }

    // only thread 0 holds the complete row sum
    if (col != 0) {
        return;
    }

    dst[row] = norm ? sum / ncols : sum;
}
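
// Illustrative host-side launch - a sketch only; the wrapper name, block size,
// and stream parameter below are assumptions for the example, not part of this
// file's interface. One block per row; the block width must be a multiple of
// WARP_SIZE and at most 1024 threads, as asserted in the kernel:
//
//   static void reduce_rows_f32_example(const float * x, float * dst,
//                                       const int nrows, const int ncols, cudaStream_t stream) {
//       const dim3 block_dims(512, 1, 1);   // 16 warps per block
//       const dim3 block_nums(nrows, 1, 1); // blockIdx.x selects the row
//       // norm = true computes the mean of each row, norm = false the sum
//       reduce_rows_f32<true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
//   }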