| 1 | #include "common.cuh" |
| 2 | #include "ggml.h" |
| 3 | #include "softmax.cuh" |
| 4 | #include <cstdint> |
| 5 | #include <utility> |
| 6 | |
| 7 | template <typename T> |
| 8 | static __device__ __forceinline__ float t2f32(T val) { |
| 9 | return (float) val; |
| 10 | } |
| 11 | |
| 12 | template <> |
| 13 | __device__ float __forceinline__ t2f32<half>(half val) { |
    return __half2float(val);
| 15 | } |
| 16 | |
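// Parameters for the forward softmax kernels: tensor geometry (ne*/nb*, as in ggml),
// the logit scale, and the ALiBi bias parameters (max_bias, m0, m1, n_head_log2).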
| 17 | struct soft_max_params { |
| 18 | |
| 19 | int64_t nheads; |
| 20 | uint32_t n_head_log2; |
| 21 | int64_t ncols; |
| 22 | int64_t nrows_x; |
| 23 | int64_t nrows_y; |
| 24 | int64_t ne00; |
| 25 | int64_t ne01; |
| 26 | int64_t ne02; |
| 27 | int64_t ne03; |
| 28 | int64_t nb11; |
| 29 | int64_t nb12; |
| 30 | int64_t nb13; |
| 31 | |
| 32 | int64_t ne12; |
| 33 | int64_t ne13; |
| 34 | float scale; |
| 35 | float max_bias; |
| 36 | float m0; |
| 37 | float m1; |
| 38 | }; |
| 39 | |
| 40 | // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. |
// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
| 42 | #ifdef __clang__ |
| 43 | #pragma clang diagnostic push |
| 44 | #pragma clang diagnostic ignored "-Wpass-failed" |
| 45 | #endif // __clang__ |
| 46 | template <bool use_shared, int ncols_template, int block_size_template, typename T> |
| 47 | static __global__ void soft_max_f32( |
| 48 | const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) { |
| 49 | const int ncols = ncols_template == 0 ? p.ncols : ncols_template; |
| 50 | |
| 51 | const int tid = threadIdx.x; |
| 52 | |
| 53 | const int64_t i03 = blockIdx.z; |
| 54 | const int64_t i02 = blockIdx.y; |
| 55 | const int64_t i01 = blockIdx.x; |
| 56 | |
    // TODO: noncontiguous inputs/outputs
| 58 | const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; |
| 59 | |
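    // the mask is broadcast across dims 2 and 3 (ggml broadcast semantics), hence the modulo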
| 60 | const int64_t i11 = i01; |
| 61 | const int64_t i12 = i02 % p.ne12; |
| 62 | const int64_t i13 = i03 % p.ne13; |
| 63 | |
| 64 | x += int64_t(rowx)*ncols; |
| 65 | mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); |
| 66 | dst += int64_t(rowx)*ncols; |
| 67 | |
| 68 | const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; |
| 69 | |
| 70 | const int warp_id = threadIdx.x / WARP_SIZE; |
| 71 | const int lane_id = threadIdx.x % WARP_SIZE; |
| 72 | |
    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
| 74 | |
| 75 | extern __shared__ float data_soft_max_f32[]; |
| 76 | float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication |
| 77 | // shared memory buffer to cache values between iterations: |
| 78 | float * vals = use_shared ? buf_iw + WARP_SIZE : dst; |
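    // when use_shared is false the row does not fit into shared memory and dst doubles as scratch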
| 79 | |
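    // attention sinks contribute one extra per-head logit to the normalization,
    // so seed the running maximum with the sink value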
| 80 | float max_val = sinks ? sinks[i02] : -INFINITY; |
| 81 | |
| 82 | #pragma unroll |
| 83 | for (int col0 = 0; col0 < ncols; col0 += block_size) { |
| 84 | const int col = col0 + tid; |
| 85 | |
| 86 | if (ncols_template == 0 && col >= ncols) { |
| 87 | break; |
| 88 | } |
| 89 | |
| 90 | const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); |
| 91 | |
| 92 | vals[col] = val; |
        max_val = max(max_val, val);
| 94 | } |
| 95 | |
| 96 | // find the max value in the block |
    max_val = warp_reduce_max(max_val);
| 98 | if (block_size > WARP_SIZE) { |
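        // block-wide reduction: warp 0 first clears the shared buffer (so unused slots
        // read as -INFINITY), each warp publishes its maximum, then all warps re-reduce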
| 99 | if (warp_id == 0) { |
| 100 | buf_iw[lane_id] = -INFINITY; |
| 101 | } |
| 102 | __syncthreads(); |
| 103 | |
| 104 | if (lane_id == 0) { |
| 105 | buf_iw[warp_id] = max_val; |
| 106 | } |
| 107 | __syncthreads(); |
| 108 | |
| 109 | max_val = buf_iw[lane_id]; |
        max_val = warp_reduce_max(max_val);
| 111 | } |
| 112 | |
| 113 | float tmp = 0.0f; // partial sum |
| 114 | |
| 115 | #pragma unroll |
| 116 | for (int col0 = 0; col0 < ncols; col0 += block_size) { |
| 117 | const int col = col0 + tid; |
| 118 | |
| 119 | if (ncols_template == 0 && col >= ncols) { |
| 120 | break; |
| 121 | } |
| 122 | |
        const float val = expf(vals[col] - max_val);
| 124 | tmp += val; |
| 125 | vals[col] = val; |
| 126 | } |
| 127 | |
| 128 | // find the sum of exps in the block |
    tmp = warp_reduce_sum(tmp);
| 130 | if (block_size > WARP_SIZE) { |
| 131 | __syncthreads(); |
| 132 | if (warp_id == 0) { |
| 133 | buf_iw[lane_id] = 0.0f; |
| 134 | } |
| 135 | __syncthreads(); |
| 136 | |
| 137 | if (lane_id == 0) { |
| 138 | buf_iw[warp_id] = tmp; |
| 139 | } |
| 140 | __syncthreads(); |
| 141 | |
| 142 | tmp = buf_iw[lane_id]; |
        tmp = warp_reduce_sum(tmp);
| 144 | } |
| 145 | |
| 146 | if (sinks) { |
        tmp += expf(sinks[i02] - max_val);
| 148 | } |
| 149 | |
| 150 | const float inv_sum = 1.0f / tmp; |
| 151 | |
| 152 | #pragma unroll |
| 153 | for (int col0 = 0; col0 < ncols; col0 += block_size) { |
| 154 | const int col = col0 + tid; |
| 155 | |
| 156 | if (ncols_template == 0 && col >= ncols) { |
| 157 | return; |
| 158 | } |
| 159 | |
| 160 | dst[col] = vals[col] * inv_sum; |
| 161 | } |
| 162 | } |
| 163 | #ifdef __clang__ |
| 164 | #pragma clang diagnostic pop |
| 165 | #endif // __clang__ |
| 166 | |
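// Backward pass: for y = softmax(scale*x) the gradient is
// dL/dx_i = scale * y_i * (dL/dy_i - sum_j y_j * dL/dy_j); one warp handles each row.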
| 167 | static __global__ void soft_max_back_f32( |
| 168 | const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { |
| 169 | const int tid = threadIdx.x; |
| 170 | const int rowx = blockIdx.x; |
| 171 | |
| 172 | grad += int64_t(rowx)*ncols; |
| 173 | dstf += int64_t(rowx)*ncols; |
| 174 | dst += int64_t(rowx)*ncols; |
| 175 | |
| 176 | float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients |
| 177 | |
| 178 | for (int col = tid; col < ncols; col += WARP_SIZE) { |
| 179 | dgf_dot += dstf[col]*grad[col]; |
| 180 | } |
| 181 | |
    dgf_dot = warp_reduce_sum(dgf_dot);
| 183 | |
| 184 | for (int col = tid; col < ncols; col += WARP_SIZE) { |
| 185 | dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; |
| 186 | } |
| 187 | } |
| 188 | |
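// Dispatch to a fully unrolled kernel when the runtime column count matches one of the
// Ns template arguments; otherwise fall back to the generic (non-unrolled) kernel.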
| 189 | template<int... Ns, typename T> |
| 190 | static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst, |
| 191 | const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) |
| 192 | { |
| 193 | const int id = ggml_cuda_get_device(); |
| 194 | const size_t smpbo = ggml_cuda_info().devices[id].smpbo; |
| 195 | |
| 196 | auto launch_kernel = [=](auto I) -> bool { |
| 197 | constexpr int ncols = decltype(I)::value; |
| 198 | constexpr int block = (ncols > 1024 ? 1024 : ncols); |
| 199 | |
| 200 | if (p.ncols == ncols) { |
| 201 | CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo); |
            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
| 203 | (x, mask, sinks, dst, p); |
| 204 | return true; |
| 205 | } |
| 206 | return false; |
| 207 | }; |
| 208 | |
| 209 | // unary fold over launch_kernel |
| 210 | if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) { |
| 211 | return; |
| 212 | } |
| 213 | |
    // default case
| 215 | CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo); |
    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
| 217 | } |
| 218 | |
| 219 | |
| 220 | template<typename T> |
| 221 | static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) { |
| 222 | int nth = WARP_SIZE; |
| 223 | const int64_t ncols_x = params.ncols; |
| 224 | |
| 225 | while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; |
| 226 | const dim3 block_dims(nth, 1, 1); |
| 227 | const dim3 block_nums(params.ne01, params.ne02, params.ne03); |
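    // shared memory: the cached row, padded to a multiple of WARP_SIZE, plus
    // WARP_SIZE floats for inter-warp reductions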
| 228 | const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); |
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
| 230 | |
| 231 | |
| 232 | const int id = ggml_cuda_get_device(); |
| 233 | const size_t smpbo = ggml_cuda_info().devices[id].smpbo; |
| 234 | |
| 235 | |
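    // use the shared-memory kernels only if the row fits within the per-block
    // opt-in shared memory limit; otherwise fall back to staging values in dst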
| 236 | if (nbytes_shared <= smpbo) { |
| 237 | launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); |
| 238 | } else { |
| 239 | const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); |
        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
| 241 | } |
| 242 | } |
| 243 | |
| 244 | static void soft_max_back_f32_cuda( |
| 245 | const float * grad, const float * dstf, float * dst, |
| 246 | const int ncols, const int nrows, const float scale, cudaStream_t stream) { |
| 247 | const dim3 block_dims(WARP_SIZE, 1, 1); |
| 248 | const dim3 block_nums(nrows, 1, 1); |
| 249 | |
    soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
| 251 | } |
| 252 | |
| 253 | void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 254 | const ggml_tensor * src0 = dst->src[0]; |
| 255 | const ggml_tensor * src1 = dst->src[1]; |
| 256 | const ggml_tensor * src2 = dst->src[2]; |
| 257 | |
| 258 | const float * src0_d = (const float *) src0->data; |
| 259 | const void * src1_d = src1 ? (const void *) src1->data : nullptr; |
| 260 | const void * src2_d = src2 ? (const void *) src2->data : nullptr; |
| 261 | float * dst_d = (float *) dst->data; |
| 262 | |
| 263 | cudaStream_t stream = ctx.stream(); |
| 264 | |
| 265 | GGML_ASSERT(src0->type == GGML_TYPE_F32); |
| 266 | GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| 267 | |
| 268 | GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional |
| 269 | |
| 270 | const int64_t nrows_x = ggml_nrows(src0); |
| 271 | const int64_t nrows_y = src0->ne[1]; |
| 272 | |
| 273 | const int64_t ne00 = src0->ne[0]; |
| 274 | |
| 275 | float scale = 1.0f; |
| 276 | float max_bias = 0.0f; |
| 277 | |
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
| 280 | |
| 281 | const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); |
| 282 | |
| 283 | const int64_t nb11 = src1 ? src1->nb[1] : 1; |
| 284 | const int64_t nb12 = src1 ? src1->nb[2] : 1; |
| 285 | const int64_t nb13 = src1 ? src1->nb[3] : 1; |
| 286 | |
| 287 | const int64_t ne12 = src1 ? src1->ne[2] : 1; |
| 288 | const int64_t ne13 = src1 ? src1->ne[3] : 1; |
| 289 | |
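    // ALiBi: n_head_log2 is the largest power of two <= n_head; m0 and m1 are the
    // slope bases from which the per-head slopes are derived in get_alibi_slope (common.cuh)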
| 290 | const uint32_t n_head = src0->ne[2]; |
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
| 292 | |
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
| 295 | |
| 296 | |
| 297 | soft_max_params params = {}; |
| 298 | params.nheads = src0->ne[2]; |
| 299 | params.n_head_log2 = n_head_log2; |
| 300 | params.ncols = ne00; |
| 301 | params.nrows_x = nrows_x; |
| 302 | params.nrows_y = nrows_y; |
| 303 | params.ne00 = src0->ne[0]; |
| 304 | params.ne01 = src0->ne[1]; |
| 305 | params.ne02 = src0->ne[2]; |
| 306 | params.ne03 = src0->ne[3]; |
| 307 | params.nb11 = nb11; |
| 308 | params.nb12 = nb12; |
| 309 | params.nb13 = nb13; |
| 310 | params.ne12 = ne12; |
| 311 | params.ne13 = ne13; |
| 312 | params.scale = scale; |
| 313 | params.max_bias = max_bias; |
| 314 | params.m0 = m0; |
| 315 | params.m1 = m1; |
| 316 | |
| 317 | if (use_f16) { |
        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
| 319 | } else { |
        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
| 321 | } |
| 322 | } |
| 323 | |
| 324 | void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 325 | const ggml_tensor * src0 = dst->src[0]; // grad |
| 326 | const ggml_tensor * src1 = dst->src[1]; // forward pass output |
| 327 | |
| 328 | const float * src0_d = (const float *) src0->data; |
| 329 | const float * src1_d = (const float *) src1->data; |
| 330 | float * dst_d = (float *) dst->data; |
| 331 | |
| 332 | cudaStream_t stream = ctx.stream(); |
| 333 | |
| 334 | GGML_ASSERT(src0->type == GGML_TYPE_F32); |
| 335 | GGML_ASSERT(src1->type == GGML_TYPE_F32); |
| 336 | GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| 337 | |
| 338 | const int64_t ncols = src0->ne[0]; |
| 339 | const int64_t nrows = ggml_nrows(src0); |
| 340 | |
| 341 | float scale = 1.0f; |
| 342 | float max_bias = 0.0f; |
| 343 | |
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
| 346 | |
| 347 | GGML_ASSERT(max_bias == 0.0f); |
| 348 | |
| 349 | soft_max_back_f32_cuda(grad: src0_d, dstf: src1_d, dst: dst_d, ncols, nrows, scale, stream); |
| 350 | } |
| 351 | |