1#include "common.cuh"
2#include "cp-async.cuh"
3#include "mma.cuh"
4#include "fattn-common.cuh"
5
6using namespace ggml_cuda_mma;
7
8typedef tile<16, 8, half2> tile_A;
9typedef tile< 8, 8, half2> tile_B;
10typedef tile<16, 8, half2> tile_B_16;
11typedef tile<16, 8, float> tile_C_KQ;
12typedef tile<16, 16, float> tile_C_KQ_16;
13typedef tile<16, 4, half2> tile_C_VKQ;
14typedef tile<16, 8, half2> tile_C_VKQ_16;
15
16// Config options for specific head sizes.
17// Should not affect results, only speed/register pressure/shared memory use.
18//
// nbatch_fa:      number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
// nwarps_max:     maximum number of warps per CUDA block; up to 8 warps in total can run per SM (given enough shared memory).
// Q_in_reg:       whether the Q values should be kept permanently in registers.
// nstages_target: targeted number of pipeline stages for cp_async (if available); 0 means synchronous data loading.
// nbatch_K2:      number of K half2 values in direction of DKQ to load in parallel.
// nbatch_V2:      number of V half2 values in direction of DV to load in parallel.
// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
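//
// As a concrete illustration (not used by the code): for the <128, 128> specialization below with ncols = 32,
// nbatch_fa = 64 means the KQ rowsums and VKQ accumulators are rescaled once every 64 KV rows, and
// nbatch_K2 = nbatch_V2 = nbatch_combine = 64 means K, V and VKQ are processed in slices of 64 half2 values,
// i.e. a full row of 128 half values at a time.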
26
27template <int DKQ, int DV>
28struct fattn_mma_f16_config;
29
30template <>
31struct fattn_mma_f16_config< 64, 64> {
32 static constexpr int nbatch_fa = 64;
33 static constexpr int nwarps_max = 4;
34 static constexpr bool Q_in_reg = true;
35 static constexpr int nstages_target = 2;
36
37 static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
38 return 32;
39 }
40
41 static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
42 return 32;
43 }
44
45 static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
46 return 32;
47 }
48
49 static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
50 return 32;
51 }
52
53 static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
54 return 32;
55 }
56
57 static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
58 return 32;
59 }
60};
61
62template <>
63struct fattn_mma_f16_config< 80, 80> {
64 static constexpr int nbatch_fa = 64;
65 static constexpr int nwarps_max = 4;
66 static constexpr bool Q_in_reg = true;
67 static constexpr int nstages_target = 2;
68
69 static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
70 return 40;
71 }
72
73 static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
74 return 40;
75 }
76
77 static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
78 return 40;
79 }
80
81 static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
82 return 40;
83 }
84
85 static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
86 return 40;
87 }
88
89 static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
90 return 40;
91 }
92};
93
94template <>
95struct fattn_mma_f16_config< 96, 96> {
96 static constexpr int nbatch_fa = 64;
97 static constexpr int nwarps_max = 4;
98 static constexpr bool Q_in_reg = true;
99 static constexpr int nstages_target = 2;
100
101 static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
102 return 48;
103 }
104
105 static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
106 return 48;
107 }
108
109 static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
110 return 48;
111 }
112
113 static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
114 return 48;
115 }
116
117 static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
118 return 48;
119 }
120
121 static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
122 return 48;
123 }
124};
125
126template <>
127struct fattn_mma_f16_config<112, 112> {
128 static constexpr int nbatch_fa = 64;
129 static constexpr int nwarps_max = 4;
130 static constexpr bool Q_in_reg = true;
131 static constexpr int nstages_target = 2;
132
133 static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
134 return 56;
135 }
136
137 static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
138 return 56;
139 }
140
141 static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
142 return 56;
143 }
144
145 static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
146 return 56;
147 }
148
149 static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
150 return 56;
151 }
152
153 static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
154 return 56;
155 }
156};
157
158template <>
159struct fattn_mma_f16_config<128, 128> {
160 static constexpr int nbatch_fa = 64;
161 static constexpr int nwarps_max = 4;
162 static constexpr bool Q_in_reg = true;
163 static constexpr int nstages_target = 2;
164
165 static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
166 return 64;
167 }
168
169 static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
170 return 64;
171 }
172
173 static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
174 return 64;
175 }
176
177 static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
178 return 64;
179 }
180
181 static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
182 return 64;
183 }
184
185 static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
186 return 64;
187 }
188};
189
190template <>
191struct fattn_mma_f16_config<256, 256> {
192 static constexpr int nbatch_fa = 32;
193 static constexpr int nwarps_max = 4;
194 static constexpr bool Q_in_reg = true;
195 static constexpr int nstages_target = 2;
196
197 static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
198 return 128;
199 }
200
201 static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
202 return 128;
203 }
204
205 static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
206 return 128;
207 }
208
209 static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
210 return 128;
211 }
212
213 static int get_nbatch_combine_host(const int cc, const int ncols) {
        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
215 return ncols <= 16 ? 128 : 64;
216 }
217 return 64;
218 }
219
220 static constexpr __device__ int get_nbatch_combine_device(int ncols) {
221#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
222 return ncols <= 16 ? 128 : 64;
223#else
224 GGML_UNUSED(ncols);
225 return 128;
226#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
227 }
228};
229
230template <>
231struct fattn_mma_f16_config<576, 512> {
232 static constexpr int nbatch_fa = 32;
233 static constexpr int nwarps_max = 8;
234 static constexpr bool Q_in_reg = false;
235 static constexpr int nstages_target = 1;
236
237 static int get_nbatch_K2_host(const int cc, const int ncols) {
        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
239 return ncols <= 16 ? 96 : 160;
240 }
241 return ncols <= 16 ? 288 : 160;
242 }
243
244 static constexpr __device__ int get_nbatch_K2_device(int ncols) {
245#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
246 return ncols <= 16 ? 96 : 160;
247#else
248 return ncols <= 16 ? 288 : 160;
249#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
250 }
251
252 static int get_nbatch_V2_host(const int cc, const int ncols) {
        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
254 return ncols <= 16 ? 64 : 128;
255 }
256 return ncols <= 16 ? 256 : 128;
257 }
258
259 static constexpr __device__ int get_nbatch_V2_device(int ncols) {
260#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
261 return ncols <= 16 ? 64 : 128;
262#else
263 return ncols <= 16 ? 256 : 128;
264#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
265 }
266
267 static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
268 return 128;
269 }
270
271 static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
272 return 128;
273 }
274};
275
276// ------------------------------------------------------------------------------------------------------------------
277
278template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
279static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
280 const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
281
    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
    // The minimum granularity with cp.async is 16 bytes; with synchronous data loading it is 4 bytes.
    // See the worked example after this function for how the k ranges are split between granularities.
284
285 if (use_cp_async) {
286 constexpr int preload = 64;
287 constexpr int h2_per_chunk = 16/sizeof(half2);
288 const int chunks_per_row = D2 / h2_per_chunk;
289
        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
291
292 auto load = [&] __device__ (auto n) {
293 const int stride_k = WARP_SIZE >> n;
294 const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
295 const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
296 const int stride_i = WARP_SIZE / stride_k;
297
298 if (k0_start == k0_stop) {
299 return;
300 }
301
302#pragma unroll
303 for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
304 const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
305
306 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
307 break;
308 }
309
310#pragma unroll
311 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
312 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
313
                cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
315 }
316 }
317 };
318 ggml_cuda_unroll<5>{}(load);
319 } else {
320 static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
321 auto load = [&] __device__ (const int n) {
322 const int stride_k = WARP_SIZE >> n;
323 const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
324 const int k0_stop = D2 - D2 % (1*stride_k);
325 const int stride_i = WARP_SIZE / stride_k;
326
327 if (k0_start == k0_stop) {
328 return;
329 }
330
331#pragma unroll
332 for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
333 const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
334
335 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
336 break;
337 }
338
339#pragma unroll
340 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
341 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
342
343 tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
344 }
345 }
346 };
347 ggml_cuda_unroll<3>{}(load);
348 }
349}
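
// Illustrative walk-through of the decreasing-granularity scheme above (synchronous path, WARP_SIZE = 32, D2 = 40,
// i.e. head size 80):
//   n = 0: stride_k = 32 -> whole warps load one half2 per thread, covering k = [0, 32).
//   n = 1: stride_k = 16 -> k0_start == k0_stop == 32, nothing to do.
//   n = 2: stride_k =  8 -> 8 threads cooperate per row (4 rows per warp), covering the tail k = [32, 40).
// This is only meant to show how the ranges are derived; the actual split depends on D2.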
350
351template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
352static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
353 const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
    static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad nbatch_fa");
355
356 if (use_cp_async) {
357 constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
358 constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
359 constexpr int stride_j = nwarps * cols_per_warp;
360
        const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
362
363#pragma unroll
364 for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
365 const int j = j0 + threadIdx.y*cols_per_warp +
366 (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
367
368 if (j0 + stride_j > ncols1 && j >= ncols1) {
369 break;
370 }
371
372 const int i = 4 * (threadIdx.x % (nbatch_fa/8));
373
374 cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
375 }
376 return;
377 }
378
379 constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
380 constexpr int stride_j = nwarps * cols_per_warp;
381#pragma unroll
382 for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
383 const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
384
385 if (j0 + stride_j > ncols1 && j >= ncols1) {
386 break;
387 }
388
389 const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
390
391 tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
392 }
393}
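
// Note on the mask tile layout above (illustrative, assuming nbatch_fa = 64): each mask row occupies
// nbatch_fa/2 + 4 = 36 half2 values, i.e. 64 half values plus 16 bytes of padding, which matches the
// byte offset j*(nbatch_fa*sizeof(half) + 16) used in the cp.async path and helps avoid shared memory
// bank conflicts between consecutive rows.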
394
395template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
396 bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
397static __device__ __forceinline__ void flash_attn_ext_f16_iter(
398 const float2 * const __restrict__ Q_f2,
399 const half2 * const __restrict__ K_h2,
400 const half2 * const __restrict__ V_h2,
401 const half2 * const __restrict__ mask_h2,
402 float2 * const __restrict__ dstk,
403 float2 * const __restrict__ dstk_fixup,
404 const float scale,
405 const float slope,
406 const float logit_softcap,
407 const int ne01,
408 const int ne02,
409 const int stride_K,
410 const int stride_V,
411 const int stride_mask,
412 half2 * const __restrict__ tile_Q,
413 half2 * const __restrict__ tile_K,
414 half2 * const __restrict__ tile_V,
415 half2 * const __restrict__ tile_mask,
416 const tile_B * const __restrict__ Q_B,
417 tile_C_VKQ * const __restrict__ VKQ_C,
418 float * const __restrict__ KQ_max,
419 float * const __restrict__ KQ_rowsum,
420 const int kb0) {
421#ifdef TURING_MMA_AVAILABLE
422 typedef fattn_mma_f16_config<DKQ, DV> c;
423
424#ifdef CP_ASYNC_AVAILABLE
425 constexpr int nstages = c::nstages_target;
426#else
427 constexpr int nstages = 0;
428#endif // CP_ASYNC_AVAILABLE
429
430 constexpr int cols_per_warp = ntiles * tile_B::I;
431 constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
432 constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
433 constexpr int ncols = ncols1 * ncols2;
434 constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
435 constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
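
    // Illustrative numbers (not used by the code): with ntiles = 2, ncols1 = 8, ncols2 = 4 and nwarps = 4:
    // cols_per_warp = 16, cols_per_thread = 2, ncols = 32 and np = 4*(16/4)/8 = 2,
    // i.e. two warps work in parallel on each Q column.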
436
437 constexpr int stride_tile_Q = DKQ/2 + 4;
438 constexpr int stride_tile_K = nbatch_K2 + 4;
439
440 static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
441 constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
442
443 const int k_VKQ_0 = kb0 * c::nbatch_fa;
444 tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
445
446 // Use wide variants of tiles if ntiles >= 2.
447 tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
448 tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
449 tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
450
451 if constexpr (nstages > 1) {
452 static_assert(!mla, "multi-stage loading not implemented for MLA");
        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage loading");
454 constexpr bool use_cp_async = true;
455 cp_async_wait_all();
456 __syncthreads();
457 flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
458 (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
459 } else {
460 constexpr bool use_cp_async = nstages == 1;
461 if (ncols2 > 1 || mask_h2) {
462 flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
463 }
464 }
465
466#pragma unroll
467 for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
468 const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
469 const int k0_diff = k0_stop - k0_start;
470
471 if (nstages <= 1) {
472 constexpr bool use_cp_async = nstages == 1;
473 flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
474 (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
475 if (use_cp_async) {
476 cp_async_wait_all();
477 }
478 __syncthreads();
479 }
480
481 // Calculate tile of KQ:
482 if constexpr (c::Q_in_reg) {
483#pragma unroll
484 for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
485 const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
486#pragma unroll
487 for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
488 tile_A K_A;
489 load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
490 if (ntiles == 1) {
491 mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
492 } else {
493#pragma unroll
494 for (int t = 0; t < ntiles/2; ++t) {
495 // Wide version of KQ_C is column-major => swap A and B.
496 mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
497 }
498 }
499 }
500 }
501 } else {
502 static_assert(ntiles == 2, "ntiles != 2 not implemented");
503#pragma unroll
504 for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
505 load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
506
507#pragma unroll
508 for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
509 const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
510
511 tile_A K_A;
512 load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
513
514 // Wide version of KQ_C is column-major => swap A and B.
515 mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
516 }
517 }
518 }
519
520 if (nstages <= 1) {
521 __syncthreads(); // Only needed if tile_K == tile_V.
522 }
523 }
524
525 if (use_logit_softcap) {
526 static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
527#pragma unroll
528 for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
529#pragma unroll
530 for (int l = 0; l < tile_C_KQ::ne; ++l) {
531 KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
532 }
533 }
534 }
535
536 float KQ_max_new[cols_per_thread];
537#pragma unroll
538 for (int col = 0; col < cols_per_thread; ++col) {
539 KQ_max_new[col] = KQ_max[col];
540 }
541 float KQ_rowsum_add[cols_per_thread] = {0.0f};
542
543 if (ntiles == 1) {
544 if (ncols2 > 1 || mask_h2) {
545#pragma unroll
546 for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
547 const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
548#pragma unroll
549 for (int l = 0; l < tile_C_KQ::ne; ++l) {
550 const int i = i0 + tile_C_KQ::get_i(l);
551 const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
552
553 KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
554 __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
555 }
556 }
557 }
558
559 // Calculate softmax for each KQ column using the current max. value.
560 // The divisor is stored in KQ_rowsum and will be applied at the end.
561 static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
562#pragma unroll
563 for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
564#pragma unroll
565 for (int l = 0; l < tile_C_KQ::ne; ++l) {
566 KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
567 }
568 }
569
        // The values of each KQ column are spread across 8 threads, so a full warp reduction is not needed:
571#pragma unroll
572 for (int col = 0; col < cols_per_thread; ++col) {
573#pragma unroll
574 for (int offset = 16; offset >= 4; offset >>= 1) {
575 KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
576 }
577 }
578
579 static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
580#pragma unroll
581 for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
582#pragma unroll
583 for (int l = 0; l < tile_C_KQ::ne; ++l) {
584 KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
585
586 KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
587 }
588 }
589 } else { // ntiles > 1
590 if (ncols2 > 1 || mask_h2) {
591#pragma unroll
592 for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
593 const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
594#pragma unroll
595 for (int t = 0; t < ntiles/2; ++t) {
596#pragma unroll
597 for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
598 const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
599 const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
600
601 const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
602 const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
603 KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
604 KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
605 }
606 }
607 }
608 }
609
610 // Calculate softmax for each KQ column using the current max. value.
611 // The divisor is stored in KQ_rowsum and will be applied at the end.
612 static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
613#pragma unroll
614 for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
615#pragma unroll
616 for (int t = 0; t < ntiles/2; ++t) {
617#pragma unroll
618 for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
619 const int KQ_index = 2*t + (l/2) % 2;
620 KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
621 }
622 }
623 }
624
        // The values of each KQ column are spread across 4 threads, so a full warp reduction is not needed:
626#pragma unroll
627 for (int col = 0; col < cols_per_thread; ++col) {
628#pragma unroll
629 for (int offset = 2; offset >= 1; offset >>= 1) {
630 KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
631 }
632 }
633
634 static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
635#pragma unroll
636 for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
637#pragma unroll
638 for (int t = 0; t < ntiles/2; ++t) {
639#pragma unroll
640 for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
641 const int KQ_index = 2*t + (l/2) % 2;
642
643 KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
644
645 KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
646 }
647 }
648 }
649 }
650
651 {
652 float KQ_max_scale[cols_per_thread];
653#pragma unroll
654 for (int col = 0; col < cols_per_thread; ++col) {
655 const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
656 KQ_max_scale[col] = expf(KQ_max_diff);
657 KQ_max[col] = KQ_max_new[col];
658
659 *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
660
661 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
662 KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
663 }
664
665 if (ntiles == 1) {
666 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
667#pragma unroll
668 for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
669#pragma unroll
670 for (int l = 0; l < tile_C_VKQ::ne; ++l) {
671 VKQ_C[i].x[l] *= KQ_max_scale_h2;
672 }
673 }
674 } else {
675#pragma unroll
676 for (int col = 0; col < cols_per_thread; ++col) {
677 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
678#pragma unroll
679 for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
680#pragma unroll
681 for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
682 VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
683 }
684 }
685 }
686 }
687 }
688
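    // Scalar sketch of the online softmax update performed above (illustrative only):
    //   m_new      = max(m_old, max_k KQ[k])
    //   scale      = exp(m_old - m_new)                  (flushed to zero below SOFTMAX_FTZ_THRESHOLD)
    //   rowsum_new = scale*rowsum_old + sum_k exp(KQ[k] - m_new)
    //   VKQ_new    = scale*VKQ_old    + sum_k exp(KQ[k] - m_new)*V[k]   (the last term is added by the mma below)
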
689 // Convert KQ C tiles into B tiles for VKQ calculation:
690 tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
691 tile_B_16 * B_16 = (tile_B_16 *) B;
692 static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
693 if (ntiles == 1) {
694#pragma unroll
695 for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
696 B[k] = get_transposed(get_half2(KQ_C[k]));
697 }
698 } else {
699 for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
700#pragma unroll
701 for (int t = 0; t < ntiles/2; ++t) {
702 B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
703 }
704 }
705 }
706
707 if (nstages > 1) {
708 // Preload K tile for next iteration:
709 constexpr bool use_cp_async = true;
710 cp_async_wait_all();
711 __syncthreads();
712 if (!last_iter) {
713 if (ncols2 > 1 || mask_h2) {
714 flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
715 (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
716 }
717 flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
718 (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
719 }
720 }
721
722
    // For MLA, K and V share the same data.
    // Therefore, iterate over V in reverse and re-use the K data already in shared memory where possible.
725 static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
726 constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
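    // Illustrative value (not used by the code): for the MLA case DKQ = 576, DV = 512 with nbatch_K2 = 160,
    // reusable_cutoff = 575 - 575 % 320 - 64 = 256, i.e. V slices starting at or beyond element 256 reuse
    // the K data that is already resident in shared memory instead of being loaded again.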
727#pragma unroll
728 for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
729 const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
730 const int i0_diff = i0_stop - i0_start;
731
732 if (nstages <= 1 && i0_start < reusable_cutoff) {
733 constexpr bool use_cp_async = nstages == 1;
734 flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
735 (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
736 if (use_cp_async) {
737 cp_async_wait_all();
738 }
739 __syncthreads();
740 }
741 const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
742
743 // Calculate VKQ tile:
744#pragma unroll
745 for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
746 static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
747#pragma unroll
748 for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
749 const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
750
751 tile_A A;
752 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
753 if (ntiles == 1) {
754 mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
755 } else {
756#pragma unroll
757 for (int t = 0; t < ntiles/2; ++t) {
758 // Wide version of VKQ_C is column-major => swap A and B.
759 mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
760 }
761 }
762 }
763 }
764
765 if (nstages <= 1) {
766 __syncthreads(); // Only needed if tile_K == tile_V.
767 }
768 }
769#else
770 GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup,
771 scale, slope, logit_softcap, ne01, ne02,
772 stride_K, stride_V, stride_mask,
773 tile_Q, tile_K, tile_V, tile_mask,
774 Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
775 NO_DEVICE_CODE;
776#endif // TURING_MMA_AVAILABLE
777}
778
779template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
780static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
781 const float2 * const __restrict__ Q_f2,
782 const half2 * const __restrict__ K_h2,
783 const half2 * const __restrict__ V_h2,
784 const half2 * const __restrict__ mask_h2,
785 const float * const __restrict__ sinks_f,
786 float2 * const __restrict__ dstk,
787 float2 * const __restrict__ dstk_fixup,
788 const float scale,
789 const float slope,
790 const float logit_softcap,
791 const int ne01,
792 const int ne02,
793 const int stride_Q1,
794 const int stride_Q2,
795 const int stride_K,
796 const int stride_V,
797 const int stride_mask,
798 const int jt,
799 const int kb0_start,
800 const int kb0_stop) {
801#ifdef TURING_MMA_AVAILABLE
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
803
804 typedef fattn_mma_f16_config<DKQ, DV> c;
805
806#ifdef CP_ASYNC_AVAILABLE
807 constexpr int nstages = c::nstages_target;
808#else
809 constexpr int nstages = 0;
810#endif // CP_ASYNC_AVAILABLE
811
812 constexpr int ncols = ncols1 * ncols2;
813 constexpr int cols_per_warp = ntiles * tile_B::I;
814 constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
815 constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
816 constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
817 constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
818
819 static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
820
821 constexpr int stride_tile_Q = DKQ/2 + 4;
822 constexpr int stride_tile_K = nbatch_K2 + 4;
823
824 static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
825 constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
826 constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
827
828 extern __shared__ half2 tile_Q[];
829 half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
830 half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
831 half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
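
    // Illustrative layout (not used by the code): for DKQ = DV = 128, ncols = 32 and nstages = 2
    // (Q_in_reg = true, nbatch_fa = 64, stride_tile_K = stride_tile_V = 68 half2):
    //   tile_K    aliases tile_Q (Q only needs shared memory until it has been loaded into registers),
    //   tile_V    starts 64*68 half2 after tile_K,
    //   tile_mask starts another 64*68 half2 after tile_V.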
832
833 tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
834 tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles];
835
836 tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
837 tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
838
839 float KQ_rowsum[cols_per_thread] = {0.0f};
840 float KQ_max[cols_per_thread];
841#pragma unroll
842 for (int col = 0; col < cols_per_thread; ++col) {
843 KQ_max[col] = -FLT_MAX/2.0f;
844 }
845
846 // Load Q data into tile_Q, either temporarily or permanently.
847 // Q in registers is faster, but register pressure is the biggest bottleneck.
848 // The loading is done with decreasing granularity for D for better memory bandwidth.
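    // The attention scale is folded into Q as it is loaded, so the KQ products computed later already include it.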
849 const half2 scale_h2 = make_half2(scale, scale);
850#pragma unroll
851 for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
852 const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
853 const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
854 const int stride_jc = WARP_SIZE / stride_k;
855
856 if (k0_start == k0_stop) {
857 continue;
858 }
859
860#pragma unroll
861 for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
862 const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
863
864 if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
865 break;
866 }
867
868 const int j = jc / ncols2;
869 const int c = jc % ncols2;
870
871 if (jt*ncols1 + j < ne01) {
872#pragma unroll
873 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
874 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
875
876 const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
877 tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
878 }
879 } else {
880#pragma unroll
881 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
882 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
883
884 tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
885 }
886 }
887 }
888 }
889
890 __syncthreads();
891
892 if (c::Q_in_reg) {
893 const int j0 = (threadIdx.y / np) * cols_per_warp;
894
895#pragma unroll
896 for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
897 if (ntiles == 1) {
898 load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
899 } else {
900#pragma unroll
901 for (int t = 0; t < ntiles/2; ++t) {
902 load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
903 tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
904 }
905 }
906 }
907 }
908
909 __syncthreads();
910
911 // Preload mask and K data for first iteration when using cp_async with multiple stages:
912 if constexpr (nstages > 1) {
913 static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
914 constexpr bool use_cp_async = true;
915 if (ncols2 > 1 || mask_h2) {
916 flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
917 (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
918 }
919 flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
920 (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
921 }
922
923 // Iterate over ne11 == previous tokens:
924 int kb0 = kb0_start;
925 for (; kb0 < kb0_stop-1; ++kb0) {
926 constexpr bool last_iter = false;
927 flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
928 (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
929 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
930 }
931 { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
932 constexpr bool last_iter = true;
933 flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
934 (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
935 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
936 }
937
    // With multi-stage loading there is no __syncthreads at the end of the iter loop,
    // so there can be a race condition on shared memory accesses for combining/writing back results.
940 if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
941 __syncthreads();
942 }
943
944 // Finally, sum up partial KQ rowsums.
    // The partial sums are spread across 8/4 threads each, so a full warp reduction is not needed.
946 {
947 constexpr int offset_first = ntiles == 1 ? 16 : 2;
948 constexpr int offset_last = ntiles == 1 ? 4 : 1;
949#pragma unroll
950 for (int col = 0; col < cols_per_thread; ++col) {
951#pragma unroll
952 for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
953 KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
954 }
955 }
956 }
957
    // If attention sinks are used, potentially re-scale if KQ_max is small.
    // Also add the sink as a value to KQ_rowsum; this is done after the synchronization of KQ_rowsum
    // so it is done unconditionally for every thread.
961 if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
962 float KQ_max_scale[cols_per_thread];
963#pragma unroll
964 for (int col = 0; col < cols_per_thread; ++col) {
965 static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
966 const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
967 const float sink = sinks_f[jc % ncols2];
968
969 const float KQ_max_new = fmaxf(KQ_max[col], sink);
970 const float KQ_max_diff = KQ_max[col] - KQ_max_new;
971 KQ_max_scale[col] = expf(KQ_max_diff);
972 KQ_max[col] = KQ_max_new;
973
974 *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
975
976 const float KQ_max_add = expf(sink - KQ_max_new);
977 KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
978 }
979
980 if (ntiles == 1) {
981 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
982#pragma unroll
983 for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
984#pragma unroll
985 for (int l = 0; l < tile_C_VKQ::ne; ++l) {
986 VKQ_C[i].x[l] *= KQ_max_scale_h2;
987 }
988 }
989 } else {
990#pragma unroll
991 for (int col = 0; col < cols_per_thread; ++col) {
992 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
993#pragma unroll
994 for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
995#pragma unroll
996 for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
997 VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
998 }
999 }
1000 }
1001 }
1002 }
1003
    // Combine VKQ accumulator values if np > 1.
    // It is also faster to do small writes to shared memory followed by a large write to VRAM
    // than to do small writes to VRAM directly.
    // Therefore, the VKQ accumulators are also written to shared memory in column-major format if np == 1.
1007
1008 constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
1009 constexpr int tile_stride = nbatch_combine + 4;
1010 static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
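
    // Illustrative numbers (not used by the code): for DV = 128 and nbatch_combine = 64, tile_stride = 68,
    // i.e. each row has 4 half2 (16 bytes) of padding; the per-column meta data (KQ max and rowsum as a float2)
    // is stored in exactly this padding, see the writes at offset nbatch_combine/2 (in float2 units) below.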
1011
1012 if constexpr (ntiles == 1) {
1013 const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
1014 const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
1015 const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
1016
1017 if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
1018 // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
1019 ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1020 }
1021
1022 __syncthreads();
1023
1024 if (np == 1) {
1025 // No combination is needed, the meta data can be directly written from registers to VRAM.
1026 if (needs_fixup && threadIdx.x < tile_B::I) {
1027 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1028 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1029 }
1030 if (is_fixup && threadIdx.x < tile_B::I) {
1031 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1032 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1033 }
1034 }
1035 } else {
1036 static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
1037 const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
1038 + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
1039 + tile_C_VKQ_16::get_i(threadIdx.x % 4);
1040 const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
1041
1042 if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
1043 // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
1044 ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1045 }
1046
1047 __syncthreads();
1048
1049 if (np == 1) {
1050 // No combination is needed, the meta data can be directly written from registers to VRAM.
1051 if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
1052 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1053 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1054 }
1055 if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
1056 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1057 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1058 }
1059 }
1060 }
1061
1062 static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
1063 if (np > 1 && threadIdx.y % np == 0) {
1064 // Combine the meta data for parallel warps via shared memory.
1065 // Warps with threadIdx.y % np != 0 must NOT return early.
1066 // All threads must return simultaneously to avoid race conditions with work on the next tile.
1067
1068 constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
1069
1070 const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
1071 float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
1072 float2 meta[nmeta];
1073#pragma unroll
1074 for (int imeta = 0; imeta < nmeta; ++imeta) {
1075 meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
1076 }
1077
1078 float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
1079#pragma unroll
1080 for (int imeta = 1; imeta < nmeta; ++imeta) {
1081 KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
1082 }
1083#pragma unroll
1084 for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1085 if (offset < WARP_SIZE) {
1086 KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
1087 }
1088 }
1089
1090 float KQ_cms[nmeta]; // KQ combine max scale per warp.
1091#pragma unroll
1092 for (int imeta = 0; imeta < nmeta; ++imeta) {
1093 KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
1094 }
1095
1096 float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
1097#pragma unroll
1098 for (int imeta = 1; imeta < nmeta; ++imeta) {
1099 KQ_crs += KQ_cms[imeta]*meta[imeta].y;
1100 }
1101#pragma unroll
1102 for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1103 if (offset < WARP_SIZE) {
1104 KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
1105 }
1106 }
1107
1108 __syncthreads();
1109
1110 // Write back combined meta data:
1111#pragma unroll
1112 for (int imeta = 0; imeta < nmeta; ++imeta) {
1113 if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
1114 // Combined KQ max scale + rowsum.
1115 meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
1116 }
1117 }
1118
1119 // Combined KQ max + rowsum.
1120 static_assert(cols_per_warp <= WARP_SIZE);
1121 if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1122 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1123 dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1124 }
1125 if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1126 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1127 dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1128 }
1129 } else if (np > 1) {
1130 // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
1131 // Therefore, all other warps also need to execute a __syncthreads().
1132 // Otherwise the points at which warps synchronize with each other would become misaligned.
1133 __syncthreads();
1134 }
1135
1136#pragma unroll
1137 for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
1138 if (ntiles == 1) {
1139 const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
1140#pragma unroll
1141 for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
1142 const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
1143
1144#pragma unroll
1145 for (int l = 0; l < tile_B::ne; ++l) {
1146 const int k = k0 + tile_B::get_j(l);
1147
1148 tile_Q[jc_cwd*tile_stride + k] = B.x[l];
1149 }
1150 }
1151 } else {
1152#pragma unroll
1153 for (int t = 0; t < ntiles/2; ++t) {
1154 const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
1155#pragma unroll
1156 for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
1157#pragma unroll
1158 for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
1159 const int j = j0 + tile_C_VKQ_16::get_i(l);
1160 const int k = k0 + tile_C_VKQ_16::get_j(l);
1161
1162 tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
1163 }
1164 }
1165 }
1166 }
1167
1168 __syncthreads();
1169
1170 if (np == 1 || threadIdx.y % np == 0) {
1171 // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
1172 // The values after that are for the partial results of the individual blocks.
1173 float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
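
            // Illustrative layout of dstk_fixup (in float2 units):
            //   [0,                 gridDim.x*ncols)  meta data written by blocks with needs_fixup == true,
            //   [gridDim.x*ncols, 2*gridDim.x*ncols)  meta data written by blocks with is_fixup == true,
            //   [2*gridDim.x*ncols, ...)              partial results, ncols*(DV/2) float2 per block (is_fixup).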
1174
1175#pragma unroll
1176 for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
1177 const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
1178 const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
1179 const int stride_jc = WARP_SIZE / stride_k;
1180
1181 if (k0_start == k0_stop) {
1182 continue;
1183 }
1184
1185#pragma unroll
1186 for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
1187 const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1188
1189 if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
1190 break;
1191 }
1192
1193 const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
1194
1195 const int j_dst = jc_dst / ncols2;
1196 const int c_dst = jc_dst % ncols2;
1197
1198 if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
1199 continue;
1200 }
1201
1202 const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
1203#pragma unroll
1204 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
1205 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1206
1207 float2 dstk_val = make_float2(0.0f, 0.0f);
1208#pragma unroll
1209 for (int ip = 0; ip < np; ++ip) {
1210 const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
1211 const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
1212 dstk_val.x += dstk_val_add.x*KQ_crs;
1213 dstk_val.y += dstk_val_add.y*KQ_crs;
1214 }
1215
1216 if (!needs_fixup && !is_fixup) {
1217 const float KQ_rowsum_j = meta_j[1];
1218 dstk_val.x /= KQ_rowsum_j;
1219 dstk_val.y /= KQ_rowsum_j;
1220 }
1221
1222 if (is_fixup) {
1223 dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
1224 } else {
1225 dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
1226 }
1227 }
1228 }
1229 }
1230 }
1231 if (np > 1) {
1232 __syncthreads();
1233 }
1234 }
1235#else
1236 GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup,
1237 scale, slope, logit_softcap, ne01, ne02,
1238 stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
1239 jt, kb0_start, kb0_stop);
1240 NO_DEVICE_CODE;
1241#endif // TURING_MMA_AVAILABLE
1242}
1243
1244template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
1245__launch_bounds__(nwarps*WARP_SIZE, 1)
1246static __global__ void flash_attn_ext_f16(
1247 const char * __restrict__ Q,
1248 const char * __restrict__ K,
1249 const char * __restrict__ V,
1250 const char * __restrict__ mask,
1251 const char * __restrict__ sinks,
1252 const int * __restrict__ KV_max,
1253 float * __restrict__ dst,
1254 float2 * __restrict__ dst_meta,
1255 const float scale,
1256 const float max_bias,
1257 const float m0,
1258 const float m1,
1259 const uint32_t n_head_log2,
1260 const float logit_softcap,
1261 const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
1262 const int32_t nb01, const int32_t nb02, const int32_t nb03,
1263 const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
1264 const int32_t nb11, const int32_t nb12, const int64_t nb13,
1265 const int32_t nb21, const int32_t nb22, const int64_t nb23,
1266 const int32_t ne31, const int32_t ne32, const int32_t ne33,
1267 const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1268#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
1269
1270 // Skip unused kernel variants for faster compilation:
1271 if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
1272 NO_DEVICE_CODE;
1273 return;
1274 }
1275#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1276 if (ncols1*ncols2 > 32) {
1277 NO_DEVICE_CODE;
1278 return;
1279 }
1280#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1281
1282 static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
1283
1284 typedef fattn_mma_f16_config<DKQ, DV> c;
1285
1286 static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
1287
1288 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
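    // Example (illustrative): 32 Q heads sharing 8 KV heads -> gqa_ratio = 4 and Q heads [4*i, 4*i + 4) use KV head i.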
1289
1290 const int stride_Q1 = nb01 / sizeof(float2);
1291 const int stride_Q2 = nb02 / sizeof(float2);
1292 const int stride_K = nb11 / sizeof(half2);
1293 const int stride_mask = nb31 / sizeof(half2);
1294
1295 const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
1296
1297 const int iter_k = ne11 / FATTN_KQ_STRIDE;
1298 const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
1299
1300 constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
1301
1302 // kbc == k block continuous, current index in continuous ijk space.
1303 int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1304 const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1305
1306 // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
1307 // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
1308 // In the most general case >2 seams can fall into the same tile.
1309
1310 // kb0 == k start index when in the output tile.
1311 int kb0_start = kbc % iter_k;
1312 int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
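
    // Worked example (illustrative): with iter_k = 8, kbc = 5 and kbc_stop = 19 this block processes
    //   kb0 = [5, 8) of the first tile  (needs_fixup == true, another block produced kb0 = [0, 5)),
    //   kb0 = [0, 8) of the second tile (complete, written directly to dst),
    //   kb0 = [0, 3) of the third tile  (is_fixup == true, written to the fixup buffer after the loop below).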
1313
1314 while (kbc < kbc_stop && kb0_stop == iter_k) {
1315 const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1316 const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1317 const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1318
1319 const int head0 = zt * ncols2;
1320
1321 const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1322 const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1323 const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1324 (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1325 float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
1326
1327 const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1328 const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1329
1330 const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1331
1332 const int kb0_start_kernel = kb0_start * kb_niter;
1333 int kb0_stop_kernel = kb0_stop * kb_niter;
1334
1335 if (KV_max) {
1336 kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
1337 }
1338
1339 constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1340 if (kb0_start == 0) {
1341 constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1342 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1343 (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1344 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1345 } else {
1346 constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
1347 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1348 (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1349 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1350 }
1351
1352 kbc += iter_k;
1353 kbc -= kbc % iter_k;
1354
1355 kb0_start = 0;
1356 kb0_stop = min(iter_k, kbc_stop - kbc);
1357 }
1358
1359 if (kbc >= kbc_stop) {
1360 return;
1361 }
1362
1363 const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1364 const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1365 const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1366
1367 const int head0 = zt * ncols2;
1368
1369 const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1370 const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1371 const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1372 (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1373 float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
1374
1375 const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1376 const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1377
1378 const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1379
1380 const int kb0_start_kernel = kb0_start * kb_niter;
1381 int kb0_stop_kernel = kb0_stop * kb_niter;
1382
1383 if (KV_max) {
1384 kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
1385 }
1386
1387 constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1388 constexpr bool needs_fixup = false;
1389 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1390 (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1391 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1392#else
1393 GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
1394 max_bias, m0, m1, n_head_log2, logit_softcap,
1395 ne00, ne01, ne02, ne03,
1396 nb01, nb02, nb03,
1397 ne10, ne11, ne12, ne13,
1398 nb11, nb12, nb13,
1399 nb21, nb22, nb23,
1400 ne31, ne32, ne33,
1401 nb31, nb32, nb33);
1402 NO_DEVICE_CODE;
1403#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
1404}
1405
1406template <int DKQ, int DV, int ncols1, int ncols2>
1407void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1408 const ggml_tensor * KQV = dst;
1409 const int id = ggml_cuda_get_device();
1410 const int cc = ggml_cuda_info().devices[id].cc;
1411
1412 typedef fattn_mma_f16_config<DKQ, DV> c;
1413
1414 const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
1415
1416 constexpr int ncols = ncols1 * ncols2;
1417 constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
1418 constexpr int cols_per_warp = ntiles * tile_B::I;
1419 constexpr int nwarps_max_x = ncols / cols_per_warp;
1420 constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
1421 constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
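
    // Illustrative numbers (not used by the code): for DKQ = DV = 128 and ncols = 32,
    // ntiles = 2, cols_per_warp = 16, nwarps_max_x = 2, nwarps_max_y = 64/16 = 4 and nwarps = min(2*4, c::nwarps_max) = 4.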
1422
1423 constexpr bool mla = DKQ == 576;
1424
1425 const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
    const int nbatch_V2      = c::get_nbatch_V2_host     (cc, ncols);
1427 const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
1428
1429 static_assert(DKQ % tile_B::J == 0, "bad DKQ");
1430 static_assert(DV % tile_A::J == 0, "bad DV");
1431 static_assert(ncols % cols_per_warp == 0, "bad ncols");
1432
    const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1434 const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
1435 const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1436 const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
1437 const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1438
1439 const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1440
1441 const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
        std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
1443 nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
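
    // Illustrative totals (not used by the code): for DKQ = DV = 128, ncols1 = 8, ncols2 = 4 and 2 pipeline stages:
    // nbytes_shared_KV = 64*(68 + 68)*4 = 34816 B, nbytes_shared_Q = 32*68*4 = 8704 B, nbytes_shared_mask = 8*36*4 = 1152 B,
    // nbytes_shared_combine = 64*68*4 = 17408 B, so nbytes_shared_total = 34816 + 1152 = 35968 B (Q_in_reg == true).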
1444
1445 float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
1447
1448 fattn_kernel_t fattn_kernel;
1449 if (logit_softcap == 0.0f) {
1450 constexpr bool use_logit_softcap = false;
1451 fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1452
1453#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1454 static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1455 if (!shared_memory_limit_raised[id]) {
1456 CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1457 shared_memory_limit_raised[id] = true;
1458 }
1459#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1460 } else {
1461 constexpr bool use_logit_softcap = true;
1462 fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1463
1464#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1465 static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1466 if (!shared_memory_limit_raised[id]) {
1467 CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1468 shared_memory_limit_raised[id] = true;
1469 }
1470#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1471 }
1472
1473 launch_fattn<DV, ncols1, ncols2>
1474 (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
1475}
1476
1477
1478#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \
1479 template void ggml_cuda_flash_attn_ext_mma_f16_case \
1480 <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
1481
1482#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \
1483 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \
1484 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \
1485 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \
1486 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \
1487 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
1488
1489DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8)
1490DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8)
1491DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8)
1492DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8)
1493DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8)
1494DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8)
1495
1496DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16)
1497DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16)
1498DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16)
1499DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16)
1500DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16)
1501DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16)
1502
1503DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32)
1504DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32)
1505DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32)
1506DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32)
1507DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32)
1508DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32)
1509
1510DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64)
1511DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64)
1512DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64)
1513DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
1514DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
1515DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
1516
// The number of viable configurations for DeepSeek is very limited:
1518extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
1519extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
1520extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
1521