1#include "common.cuh"
2#include "ggml.h"
3#include "softmax.cuh"
4#include <cstdint>
5#include <utility>
6
// Converts one mask element to float. The generic version relies on the type's own
// conversion to float (used for float masks).
template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
    const float converted = static_cast<float>(val);
    return converted;
}

// half masks go through the dedicated CUDA conversion intrinsic.
template <>
__device__ __forceinline__ float t2f32<half>(half val) {
    const float converted = __half2float(val);
    return converted;
}
16
// Parameters for the soft_max_f32 kernels. Filled in by ggml_cuda_op_soft_max and passed to
// the kernel by value.
struct soft_max_params {

    int64_t nheads;       // src0->ne[2]; not read by the kernels in this file
    uint32_t n_head_log2; // largest power of two <= number of heads; forwarded to get_alibi_slope
    int64_t ncols;        // row length (src0->ne[0]); used when the kernel is not specialized on ncols
    int64_t nrows_x;      // total number of rows of src0 (ggml_nrows); not read by the kernels in this file
    int64_t nrows_y;      // src0->ne[1]; not read by the kernels in this file
    int64_t ne00;         // src0 dimensions; ne01..ne03 are used host-side as the launch grid (one block per row)
    int64_t ne01;
    int64_t ne02;
    int64_t ne03;
    int64_t nb11;         // mask (src1) strides for dims 1..3; divided by sizeof(mask element) in the kernel,
    int64_t nb12;         // i.e. byte strides. Set to 1 when there is no mask.
    int64_t nb13;

    int64_t ne12;         // mask sizes in dims 2 and 3; the kernel broadcasts via i02 % ne12 and i03 % ne13
    int64_t ne13;
    float scale;          // multiplier applied to every input value (op_params[0])
    float max_bias;       // op_params[1]; forwarded to get_alibi_slope
    float m0;             // slope bases forwarded to get_alibi_slope:
    float m1;             //   m0 = 2^(-max_bias / n_head_log2), m1 = 2^(-(max_bias / 2) / n_head_log2)
};
39
// When ncols_template == 0, the loop bounds in this function are not known at compile time, so the loops can't be unrolled.
// Because we want to keep `#pragma unroll` for all other cases, we suppress the resulting clang transformation warning here.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
// Row-wise softmax: dst = softmax(x*scale + slope*mask), with an optional extra "sink" logit per head.
//
// Launch layout (see soft_max_f32_cuda): one thread block per row with grid = (ne01, ne02, ne03),
// blockDim.x threads per block. Rows of x and dst are assumed to be contiguous (see TODO below).
//
// Template parameters:
//   use_shared          - cache intermediate values in dynamic shared memory after the first WARP_SIZE
//                         floats; otherwise dst itself is used as scratch space.
//   ncols_template      - compile-time row length, or 0 to use p.ncols at runtime.
//   block_size_template - compile-time block size, or 0 to use blockDim.x.
//   T                   - mask element type (float or half).
//
// Dynamic shared memory: WARP_SIZE floats for the inter-warp reductions, plus the row cache
// when use_shared is true.
//
// sinks (optional): one value per head (indexed by blockIdx.y). It takes part in the row max and in
// the normalization sum, but has no output column of its own.
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
        const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;

    const int tid = threadIdx.x;

    const int64_t i03 = blockIdx.z;
    const int64_t i02 = blockIdx.y;
    const int64_t i01 = blockIdx.x;

    //TODO: noncontigous inputs/outputs
    // Flattened row index. It is computed in 64 bits because the 32-bit product of the grid
    // dimensions (and its conversion to int) overflows for very large tensors.
    const int64_t rowx = i01 + i02*int64_t(gridDim.x) + i03*int64_t(gridDim.x)*int64_t(gridDim.y);

    // The mask is broadcast over dims 2 and 3 via modulo indexing.
    const int64_t i11 = i01;
    const int64_t i12 = i02 % p.ne12;
    const int64_t i13 = i03 % p.ne13;

    x   += rowx*ncols;
    dst += rowx*ncols;
    if (mask) {
        // nb1x are byte strides, hence the division by the element size.
        mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T);
    }

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);

    extern __shared__ float data_soft_max_f32[];
    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
    // buffer to cache values between iterations: shared memory, or dst itself when the row does not fit
    float * vals = use_shared ? buf_iw + WARP_SIZE : dst;

    // The sink (if any) participates in the max so that exp(sink - max) cannot overflow.
    float max_val = sinks ? sinks[i02] : -INFINITY;

    // Pass 1: scaled + masked logits, cached in vals, and the per-thread max.
#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);

        vals[col] = val;
        max_val = max(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        // Reset the slots so that warps which don't exist read the neutral element.
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = max_val;
        }
        __syncthreads();

        max_val = buf_iw[lane_id];
        max_val = warp_reduce_max(max_val);
    }

    float tmp = 0.0f; // partial sum

    // Pass 2: exponentiate in place and accumulate the per-thread sum.
#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const float val = expf(vals[col] - max_val);
        tmp += val;
        vals[col] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // Every thread must have read buf_iw (max reduction) before it is overwritten.
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp;
        }
        __syncthreads();

        tmp = buf_iw[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    if (sinks) {
        tmp += expf(sinks[i02] - max_val);
    }

    const float inv_sum = 1.0f / tmp;

    // Pass 3: normalize and write the result.
#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            return;
        }

        dst[col] = vals[col] * inv_sum;
    }
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
166
// Softmax backward pass for one row per block. The launch uses exactly one warp (WARP_SIZE threads)
// per row; see soft_max_back_f32_cuda.
//   dot    = sum_i dstf[i] * grad[i]              (reduced across the warp)
//   dst[i] = scale * (grad[i] - dot) * dstf[i]
// grad: incoming gradient, dstf: output of the forward softmax, dst: resulting gradient.
static __global__ void soft_max_back_f32(
        const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
    const int     lane       = threadIdx.x;
    const int64_t row_offset = int64_t(blockIdx.x)*ncols;

    const float * grad_row = grad + row_offset;
    const float * fwd_row  = dstf + row_offset;
    float       * out_row  = dst  + row_offset;

    float dot = 0.0f;
    for (int i = lane; i < ncols; i += WARP_SIZE) {
        dot += fwd_row[i]*grad_row[i];
    }
    dot = warp_reduce_sum(dot);

    for (int i = lane; i < ncols; i += WARP_SIZE) {
        out_row[i] = scale * (grad_row[i] - dot) * fwd_row[i];
    }
}
188
// Launches soft_max_f32 with a compile-time row length and block size when p.ncols equals one of
// the template arguments Ns. The first match wins. Otherwise it falls back to the runtime-sized
// soft_max_f32<true, 0, 0>. Every variant launched here uses the shared-memory row cache, so the
// caller must ensure nbytes_shared fits the device's opt-in shared memory limit (smpbo).
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
        const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
    const int    dev_id = ggml_cuda_get_device();
    const size_t smpbo  = ggml_cuda_info().devices[dev_id].smpbo;

    // Launches the specialization for one compile-time column count; returns whether it did.
    auto try_launch = [=](auto ncols_tag) -> bool {
        constexpr int ncols_fixed = decltype(ncols_tag)::value;
        constexpr int block_fixed = ncols_fixed > 1024 ? 1024 : ncols_fixed;

        if (p.ncols != ncols_fixed) {
            return false;
        }

        CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols_fixed, block_fixed, T>), smpbo);
        soft_max_f32<true, ncols_fixed, block_fixed><<<block_nums, block_dims, nbytes_shared, stream>>>
            (x, mask, sinks, dst, p);
        return true;
    };

    // Short-circuiting fold over ||: stops at the first matching Ns.
    const bool launched_specialized = (try_launch(std::integral_constant<int, Ns>{}) || ...);

    if (!launched_specialized) {
        // No specialization matched: use the runtime-sized kernel.
        CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
        soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
    }
}
218
219
// Host entry point for the forward softmax. It picks a block size, then runs one of two paths:
// - the shared-memory-cached kernels, when a whole row fits into the device's opt-in shared memory;
// - otherwise a fallback kernel that keeps only the reduction buffer in shared memory and uses
//   dst as scratch space.
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    const int64_t ncols_x = params.ncols;

    // Start at one warp and double until the row is covered or the block size cap is reached.
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) {
        nth *= 2;
    }

    const dim3 block_dims(nth, 1, 1);
    const dim3 block_nums(params.ne01, params.ne02, params.ne03); // one block per row

    // Row cache (ncols rounded up to WARP_SIZE) followed by WARP_SIZE floats for the reductions.
    const size_t nbytes_shared     = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);

    const int    dev_id = ggml_cuda_get_device();
    const size_t smpbo  = ggml_cuda_info().devices[dev_id].smpbo;

    if (nbytes_shared > smpbo) {
        // Row does not fit: only the reduction buffer lives in shared memory.
        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
        return;
    }

    launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
}
243
// Launches soft_max_back_f32: one block of WARP_SIZE threads per row, no dynamic shared memory.
static void soft_max_back_f32_cuda(
        const float * grad, const float * dstf, float * dst,
        const int ncols, const int nrows, const float scale, cudaStream_t stream) {
    const dim3 grid(nrows, 1, 1);
    const dim3 threads(WARP_SIZE, 1, 1);

    soft_max_back_f32<<<grid, threads, 0, stream>>>(grad, dstf, dst, ncols, scale);
}
252
// CUDA implementation of the softmax op:
//   dst = softmax(src0*scale + slope*src1) along dim 0, with optional per-head sinks from src2.
// src0: F32 input. src1: optional F16/F32 mask. src2: optional F32 sinks, one value per src0 dim-2 index.
// op_params[0] = scale, op_params[1] = max_bias (used for the slope via get_alibi_slope, presumably ALiBi).
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1]; // mask (optional)
    const ggml_tensor * src2 = dst->src[2]; // sinks (optional)

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
    // The kernel reads the sinks through a float pointer, so any other type would be silently misread.
    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains sinks and it is optional

    const float * src0_d = (const float *) src0->data;
    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
    const float * src2_d = src2 ? (const float *) src2->data : nullptr;
    float       * dst_d  = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    float scale    = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

    // Largest power of two <= number of heads.
    const uint32_t n_head      = src0->ne[2];
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    soft_max_params params = {};
    params.nheads      = src0->ne[2];
    params.n_head_log2 = n_head_log2;
    params.ncols       = src0->ne[0];
    params.nrows_x     = ggml_nrows(src0);
    params.nrows_y     = src0->ne[1];
    params.ne00        = src0->ne[0];
    params.ne01        = src0->ne[1];
    params.ne02        = src0->ne[2];
    params.ne03        = src0->ne[3];
    // Without a mask the strides/sizes are dummies; 1 keeps the kernel's modulo indexing valid.
    params.nb11        = src1 ? src1->nb[1] : 1;
    params.nb12        = src1 ? src1->nb[2] : 1;
    params.nb13        = src1 ? src1->nb[3] : 1;
    params.ne12        = src1 ? src1->ne[2] : 1;
    params.ne13        = src1 ? src1->ne[3] : 1;
    params.scale       = scale;
    params.max_bias    = max_bias;
    params.m0          = powf(2.0f, -(max_bias       ) / n_head_log2);
    params.m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

    if (use_f16) {
        soft_max_f32_cuda(src0_d, (const half *) src1_d, src2_d, dst_d, params, stream);
    } else {
        soft_max_f32_cuda(src0_d, (const float *) src1_d, src2_d, dst_d, params, stream);
    }
}
323
// CUDA implementation of the softmax backward op.
// src0 = incoming gradient, src1 = output of the forward softmax; both F32, processed row by row
// with the row length taken from src0 (same layout as src0 assumed for src1 - TODO confirm).
// op_params[0] = scale, op_params[1] = max_bias, which must be 0 (not supported here).
void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad_t = dst->src[0]; // grad
    const ggml_tensor * fwd_t  = dst->src[1]; // forward pass output

    const float * grad_d = (const float *) grad_t->data;
    const float * fwd_d  = (const float *) fwd_t->data;
    float       * out_d  = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(grad_t->type == GGML_TYPE_F32);
    GGML_ASSERT(fwd_t->type  == GGML_TYPE_F32);
    GGML_ASSERT(dst->type    == GGML_TYPE_F32);

    float op_scale    = 1.0f;
    float op_max_bias = 0.0f;
    memcpy(&op_scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&op_max_bias, (const float *) dst->op_params + 1, sizeof(float));

    GGML_ASSERT(op_max_bias == 0.0f);

    const int64_t row_len  = grad_t->ne[0];
    const int64_t row_cnt  = ggml_nrows(grad_t);

    soft_max_back_f32_cuda(grad_d, fwd_d, out_d, row_len, row_cnt, op_scale, stream);
}
351