1// Old and deprecated WMMA FlashAttention implementation.
2// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
3// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
4
5#include "common.cuh"
6#include "fattn-common.cuh"
7#include "fattn-wmma-f16.cuh"
8
9#ifdef GGML_USE_WMMA_FATTN
10#if !defined(GGML_USE_HIP)
11#include <mma.h>
12#if defined(GGML_USE_MUSA)
13namespace wmma = mtmusa::wmma;
14#else // GGML_USE_MUSA
15namespace wmma = nvcuda::wmma;
16#endif // GGML_USE_MUSA
17#elif defined(GGML_USE_HIP)
18#include <rocwmma/rocwmma.hpp>
19namespace wmma = rocwmma;
20#endif // !defined(GGML_USE_HIP)
21#endif // GGML_USE_WMMA_FATTN
22
// FlashAttention kernel for FP16 K/V built on WMMA tensor-core fragments.
//
// Template parameters:
//   D                 - head size; every Q/K/V row holds D elements.
//   ncols             - number of Q columns handled per block (8 or a multiple of 16).
//   nwarps            - number of warps per block; threadIdx.y is used as the warp index and
//                       threadIdx.x as the lane index (expects blockDim = (warp_size, nwarps)).
//   VKQ_stride        - number of VKQ rows calculated in parallel.
//   KQ_acc_t          - accumulator type for KQ = K*Q (float or half); selects the float or
//                       half2 softmax path below.
//   use_logit_softcap - if true, KQ is replaced by logit_softcap*tanh(KQ) before the softmax.
//
// Work decomposition (as read from the indexing below):
//   blockIdx.x - tile of ncols Q columns starting at column ncols*blockIdx.x,
//   blockIdx.y - KV tiles k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE, stepping by gridDim.y*FATTN_KQ_STRIDE,
//   blockIdx.z - sequence*ne02 + head.
// With gridDim.y == 1 the normalized result is written to dst. Otherwise each block writes an
// unnormalized partial result to dst and, per Q column, its (KQ max, KQ rowsum) pair to dst_meta.
// K and V rows are read in whole FATTN_KQ_STRIDE tiles without bounds checks; this presumably
// relies on the launcher padding the KV length - TODO confirm against launch_fattn.
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void flash_attn_ext_f16(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
    // Skip unused kernel variants for faster compilation:
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        NO_DEVICE_CODE;
        return;
    }

    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.

    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
    constexpr int frag_m = ncols == 8 ? 32 : 16;
    constexpr int frag_n = ncols == 8 ?  8 : 16;
    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>              frag_c_KQ;
    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                  frag_c_VKQ;

    constexpr int KQ_stride_tc = nwarps*frag_m;           // Number of KQ rows calculated in parallel.
    constexpr int VKQ_ratio    = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");

    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
    constexpr int D_padded   = D + 8;
    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
    constexpr int kqar       = sizeof(KQ_acc_t)/sizeof(half); // Number of half slots occupied by one KQ_acc_t value.

    const int sequence  = blockIdx.z / ne02;
    const int head      = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float * Q_f    = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
    const half  * K_h    = (const half  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
    // NOTE(review): V is addressed with K's strides (nb12, nb13 and stride_KV from nb11);
    // nb21..nb23 are not used, so this assumes V's memory layout matches K's.
    const half  * V_h    = (const half  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
    const half2 * mask2  = (const half2 *)  maskh;
    const float * sinksf = (const float *) sinks;

    const int stride_Q  = nb01 / sizeof(float);
    const int stride_KV = nb11 / sizeof(half);

    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half  slopeh = __float2half(slopef);
    const half2 slope2 = make_half2(slopef, slopef);

    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);

    frag_b Q_b[D/16][ncols/frag_n];

    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
    constexpr int mem_KQ        = ncols*kqs_padded*kqar;
    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
    float * KQ_f = (float *) KQ;
    half2 * KQ2  = (half2 *) KQ;

    // Online-softmax state, one entry per Q column handled by this warp (float path):
    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
    float       KQ_max_f[ncols/nwarps];
    float KQ_max_scale_f[ncols/nwarps] = {0.0f};

#pragma unroll
    for (int j = 0; j < ncols/nwarps; ++j) {
        KQ_max_f[j] = -FLT_MAX/2.0f;
    }

    // Online-softmax state, one entry per Q column handled by this warp (half2 path):
    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
    half2       KQ_max_h2[ncols/nwarps];
    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};

#pragma unroll
    for (int j = 0; j < ncols/nwarps; ++j) {
        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
    }

    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
    half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
            if (i0 + warp_size > D/2 && i >= D/2) {
                break;
            }
            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
        }
    }

    // Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < D; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
            if (i0 + warp_size > D && i >= D) {
                break;
            }
            // Columns past ne01 are zero-filled so the tensor-core math below stays well-defined.
            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
        }
    }

    __syncthreads();

    // Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
    for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
        }
    }

    __syncthreads();

    // Iterate over ne11 == previous tokens:
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
        // Calculate tile of KQ:
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
            frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
            for (int j = 0; j < ncols/frag_n; ++j) {
                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
            }
#pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                frag_a_K K_a;
                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
                for (int j = 0; j < ncols/frag_n; ++j) {
                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                }
            }
#pragma unroll
            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
            }
        }

        __syncthreads();

        // Calculate softmax for each KQ column using the current max. value.
        // The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (std::is_same<KQ_acc_t, float>::value) {
                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];

                    if (use_logit_softcap) {
                        KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
                    }
                }

                float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
                }
                KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);

                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
                KQ_max_scale_f[j0/nwarps] = expf(diff);
                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                    KQ_max_scale_f[j0/nwarps] = 0.0f;
                }
                KQ_max_f[j0/nwarps] = KQ_max_new;

                float KQ_rowsum_add = 0.0f;
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
                    KQ_f_tmp[k0/warp_size] = expf(diff);
                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                        KQ_f_tmp[k0/warp_size] = 0.0f;
                    }
                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
                    // Write the probabilities back as half at the start of this column's row so
                    // the V*KQ step below can load them as half fragments (row stride kqar*kqs_padded).
                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
                }
                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);

                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
            } else {
                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];

                    if (use_logit_softcap) {
                        // There is no dedicated hyperbolic tangent function for half2, so use
                        // tanh(x) = (exp(2x) - 1)/(exp(2x) + 1).
                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
                                               /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));

                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                    }
                }

                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    // Address mask rows through their byte stride nb31, exactly like the float path
                    // above, instead of assuming every mask row holds exactly ne11 elements.
                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[j*(nb31/sizeof(half2)) + k_VKQ_0/2 + k] : make_half2(0.0f, 0.0f);
                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                }
                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                // Flush-to-zero: clear each half whose exponent argument is at or below the threshold.
                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
                KQ_max_h2[j0/nwarps] = KQ_max_new;

                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
                    KQ2_tmp[k0/warp_size] = h2exp(diff);
                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];
                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
                }
                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);

                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
            }
        }

        __syncthreads();

        // Load the softmax probabilities (stored as half in KQ) into B fragments:
        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                wmma::load_matrix_sync(
                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                    KQ + j0*(kqar*kqs_padded) + k,
                    kqar*kqs_padded);
            }
        }

        // Partial V*KQ products; warps with the same threadIdx.y/VKQ_ratio split the k range.
        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
            for (int j = 0; j < ncols/frag_n; ++j) {
                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
            }

#pragma unroll
            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                frag_a_V v_a;
                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
                for (int j = 0; j < ncols/frag_n; ++j) {
                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                }
            }
        }

        __syncthreads();

        // Each of the VKQ_ratio partial results goes to its own ncols*D_padded slice of KQ:
        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                wmma::store_matrix_sync(
                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
                    D_padded, wmma::mem_col_major);
            }
        }

        __syncthreads();

        // Rescale the running VKQ accumulator by the max correction and add the summed partials:
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            half2 VKQ_scale;
            if (std::is_same<KQ_acc_t, float>::value) {
                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
            } else {
                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
            }

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                const int i = i0 + threadIdx.x;
                if (i0 + warp_size > D/2 && i >= D/2) {
                    break;
                }

                half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
                for (int l = 0; l < VKQ_ratio; ++l) {
                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
                }
                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
            }
        }

        __syncthreads();
    }

    // Apply attention sinks (only once per Q column, by the blockIdx.y == 0 block):
    if (sinksf && blockIdx.y == 0) {
        const float sinkf = sinksf[head];
        const half  sinkh = __float2half(sinkf);

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (std::is_same<KQ_acc_t, float>::value) {
                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);

                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
                KQ_max_f[j0/nwarps] = kqmax_new;

                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);

                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                    const int i = i0 + threadIdx.x;
                    if (i0 + warp_size > D/2 && i >= D/2) break;
                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
                }
            } else {
                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
                half kqmax_new = fmaxf(kqmax_old, sinkh);
                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);

                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);

                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
                // Only the low half receives the sink term; the final rowsum adds low + high.
                const half val = hexp(sinkh - kqmax_new);
                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);

#pragma unroll
                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                    const int i = i0 + threadIdx.x;
                    if (i0 + warp_size > D/2 && i >= D/2) break;
                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
                }
            }
        }

        __syncthreads();
    }

    // Write results (and, for gridDim.y > 1, the per-column softmax metadata):
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j_VKQ = j0 + threadIdx.y;
        // j_VKQ only grows with j0 and no barrier follows, so returning early is safe.
        if (ic0 + j_VKQ >= ne01) {
            return;
        }

        float KQ_rowsum_j;
        if (std::is_same<KQ_acc_t, float>::value) {
            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
        } else {
            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
        }

        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < D; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
            if (i0 + warp_size > D && i >= D) {
                break;
            }
            float dst_val = VKQ[j_VKQ*D_padded + i];
            if (gridDim.y == 1) {
                dst_val /= KQ_rowsum_j;
            }
            dst[j_dst_unrolled*D + i] = dst_val;
        }

        if (gridDim.y == 1 || threadIdx.x != 0) {
            continue;
        }

        float2 dst_meta_val;
        if (std::is_same<KQ_acc_t, float>::value) {
            dst_meta_val.x = KQ_max_f[j0/nwarps];
        } else {
            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
        }
        dst_meta_val.y = KQ_rowsum_j;
        dst_meta[j_dst_unrolled] = dst_meta_val;
    }
#else
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
        nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb11, nb12, nb13,
        nb21, nb22, nb23,
        ne31, ne32, ne33,
        nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
}
486
// Returns the largest power of 2 that divides x (1 for odd x).
// Every power of 2 divides 0, so there is no largest one; return 1 for x == 0 instead of
// recursing forever.
constexpr int get_max_power_of_2(int x) {
    return x != 0 && x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}

static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
static_assert(get_max_power_of_2(0) == 1, "Test failed.");
495
// Number of VKQ rows calculated in parallel:
// frag_m times the largest power of 2 dividing D/frag_m, capped at nwarps.
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}

static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
510
// Launches flash_attn_ext_f16 for a fixed head size D, column tile size cols_per_block and
// KQ accumulator type, choosing the logit-softcap variant at runtime from the op parameters.
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;

    constexpr int nwarps = 4;

    // Fragment height used to size VKQ_stride; 32 only for 8-column tiles with D % 32 == 0.
    constexpr int frag_m     = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
    constexpr int VKQ_stride = get_VKQ_stride(D, nwarps, frag_m);

    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;

    // The logit softcap is stored as a float bit pattern in op_params[2].
    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    fattn_kernel_t fattn_kernel;
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, VKQ_stride, KQ_acc_t, use_logit_softcap>;
    } else {
        constexpr bool use_logit_softcap = true;
        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, VKQ_stride, KQ_acc_t, use_logit_softcap>;
    }
    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
}
535
// Launches the WMMA kernel for one (cols_per_block, KQ_acc_t) configuration, dispatching on the
// head size Q->ne[0]. Handles head sizes 64, 80, 96, 112, 128 and 256; aborts on anything else.
template <int cols_per_block, typename KQ_acc_t>
static void ggml_cuda_flash_attn_ext_wmma_f16_switch_D(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    switch (dst->src[0]->ne[0]) {
        case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, KQ_acc_t>(ctx, dst); break;
        case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, KQ_acc_t>(ctx, dst); break;
        case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, KQ_acc_t>(ctx, dst); break;
        case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, KQ_acc_t>(ctx, dst); break;
        case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, KQ_acc_t>(ctx, dst); break;
        case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, KQ_acc_t>(ctx, dst); break;
        default:
            GGML_ABORT("fatal error");
    }
}

// Picks the number of Q columns per block and the KQ accumulator type for the WMMA
// FlashAttention kernel, then launches it. KQ is accumulated in FP16 for GGML_PREC_DEFAULT
// and in FP32 for any other requested precision.
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    const enum ggml_prec prec      = ggml_flash_attn_ext_get_prec(dst);
    const int            warp_size = ggml_cuda_info().devices[ctx.device].warp_size;

    const int64_t head_size = Q->ne[0];
    const int64_t n_q_cols  = Q->ne[1];

    if (prec == GGML_PREC_DEFAULT) {
#if !defined(GGML_USE_HIP)
        // 8-column tiles: only used when the head size is a multiple of the warp size.
        if (n_q_cols <= 8 && head_size % warp_size == 0) {
            switch (head_size) {
                case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 8, half>(ctx, dst); break;
                case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 8, half>(ctx, dst); break;
                case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, 8, half>(ctx, dst); break;
                case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, 8, half>(ctx, dst); break;
                default:
                    GGML_ABORT("fatal error");
            }
            return;
        }
#endif // !defined(GGML_USE_HIP)

        if (n_q_cols <= 32) {
            ggml_cuda_flash_attn_ext_wmma_f16_switch_D<16, half>(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_wmma_f16_switch_D<32, half>(ctx, dst);
        }
        return;
    }

    // FP32 accumulation from here on.
    if (n_q_cols <= 32 || head_size > 128) {
        ggml_cuda_flash_attn_ext_wmma_f16_switch_D<16, float>(ctx, dst);
        return;
    }

    // 32-column FP32 tiles: head sizes above 128 took the 16-column path above, and 256 is not
    // instantiated for this configuration.
    switch (head_size) {
        case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 32, float>(ctx, dst); break;
        case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 32, float>(ctx, dst); break;
        case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 32, float>(ctx, dst); break;
        case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, 32, float>(ctx, dst); break;
        case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, 32, float>(ctx, dst); break;
        default:
            GGML_ABORT("fatal error");
    }
}
675