| 1 | // Old and deprecated WMMA FlashAttention implementation. |
| 2 | // It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing. |
| 3 | // Long-term the WMMA code should be replaced with a dedicated Volta implementation. |
| 4 | |
| 5 | #include "common.cuh" |
| 6 | #include "fattn-common.cuh" |
| 7 | #include "fattn-wmma-f16.cuh" |
| 8 | |
| 9 | #ifdef GGML_USE_WMMA_FATTN |
| 10 | #if !defined(GGML_USE_HIP) |
| 11 | #include <mma.h> |
| 12 | #if defined(GGML_USE_MUSA) |
| 13 | namespace wmma = mtmusa::wmma; |
| 14 | #else // GGML_USE_MUSA |
| 15 | namespace wmma = nvcuda::wmma; |
| 16 | #endif // GGML_USE_MUSA |
| 17 | #elif defined(GGML_USE_HIP) |
| 18 | #include <rocwmma/rocwmma.hpp> |
| 19 | namespace wmma = rocwmma; |
| 20 | #endif // !defined(GGML_USE_HIP) |
| 21 | #endif // GGML_USE_WMMA_FATTN |
| 22 | |
// WMMA (tensor-core fragment) FlashAttention kernel, compiled for Volta or for HIP with rocWMMA
// (see the #if guard below); other targets get NO_DEVICE_CODE.
//
// Template parameters:
//   D                 - head size. One D is used for the Q/K row length and the V row length.
//   ncols             - number of Q columns (tokens) handled per block; 8 or a multiple of 16.
//   nwarps            - warps per block, indexed by threadIdx.y. The kernel is bounded to
//                       nwarps*physical_warp_size threads; threadIdx.x is the lane.
//   VKQ_stride        - number of VKQ rows calculated in parallel.
//   KQ_acc_t          - KQ accumulator type (float or half). It also selects the FP32 or FP16
//                       softmax path.
//   use_logit_softcap - apply logit_softcap*tanh(x) to KQ before the softmax.
//
// Grid layout as read by the kernel:
//   blockIdx.x selects the tile of ncols Q columns starting at ncols*blockIdx.x.
//   blockIdx.y strides over the KV sequence in chunks of FATTN_KQ_STRIDE.
//   blockIdx.z encodes sequence*ne02 + head.
//
// Output: with gridDim.y == 1, dst receives the normalized result. Otherwise dst receives
// unnormalized partial sums, and dst_meta receives (running max, rowsum) per row. These are
// presumably merged by a separate combine step.
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void flash_attn_ext_f16(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
    // Skip unused kernel variants for faster compilation:
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        NO_DEVICE_CODE;
        return;
    }

    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.

    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
    constexpr int frag_m = ncols == 8 ? 32 : 16;
    constexpr int frag_n = ncols == 8 ?  8 : 16;
    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>              frag_c_KQ;
    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                  frag_c_VKQ;

    constexpr int KQ_stride_tc = nwarps*frag_m;            // Number of KQ rows calculated in parallel.
    constexpr int VKQ_ratio    = KQ_stride_tc/VKQ_stride;  // Number of parallel VKQ accumulators needed to keep all warps busy.
    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");

    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
    constexpr int D_padded   = D + 8;
    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
    constexpr int kqar       = sizeof(KQ_acc_t)/sizeof(half); // Number of half slots per KQ accumulator element.

    const int sequence  = blockIdx.z / ne02;
    const int head      = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float * Q_f   = (const float *) (Q + nb03*sequence + nb02*head + nb01*ic0);
    const half  * K_h   = (const half  *) (K + nb13*sequence + nb12*(head / gqa_ratio));
    // V is addressed with its own strides. They equal K's when K and V share a memory layout,
    // but they are passed separately and may differ.
    const half  * V_h   = (const half  *) (V + nb23*sequence + nb22*(head / gqa_ratio));
    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
    const half2 * mask2 = (const half2 *) maskh;
    const float * sinksf = (const float *) sinks;

    const int stride_Q    = nb01 / sizeof(float);
    const int stride_K    = nb11 / sizeof(half);
    const int stride_V    = nb21 / sizeof(half);
    // Mask row stride in half elements. Both softmax paths must use this, not ne11: the two only
    // coincide for a contiguous mask with exactly ne11 columns. The half2 path assumes the stride
    // is even (TODO confirm the mask is always padded accordingly).
    const int stride_mask = nb31 / sizeof(half);

    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half  slopeh = __float2half(slopef);
    const half2 slope2 = make_half2(slopef, slopef);

    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);

    frag_b Q_b[D/16][ncols/frag_n];

    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
    constexpr int mem_KQ        = ncols*kqs_padded*kqar;
    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
    float * KQ_f = (float *) KQ;
    half2 * KQ2  = (half2 *) KQ;

    // Per-row softmax state. Each warp owns rows j0 + threadIdx.y for j0 = 0, nwarps, ...
    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
    float       KQ_max_f[ncols/nwarps];
    float KQ_max_scale_f[ncols/nwarps] = {0.0f};

#pragma unroll
    for (int j = 0; j < ncols/nwarps; ++j) {
        KQ_max_f[j] = -FLT_MAX/2.0f;
    }

    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
    half2       KQ_max_h2[ncols/nwarps];
    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};

#pragma unroll
    for (int j = 0; j < ncols/nwarps; ++j) {
        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
    }

    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
    half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
            if (i0 + warp_size > D/2 && i >= D/2) {
                break;
            }
            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
        }
    }

    // Convert Q to half and apply scale, temporarily store in KQ. Out-of-range columns are zeroed.
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < D; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
            if (i0 + warp_size > D && i >= D) {
                break;
            }
            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
        }
    }

    __syncthreads();

    // Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
    for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
        }
    }

    __syncthreads();

    // Iterate over ne11 == previous tokens (bounded by KV_max when provided):
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
        // Calculate tile of KQ:
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
            frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
            for (int j = 0; j < ncols/frag_n; ++j) {
                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
            }
#pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                frag_a_K K_a;
                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_K + k_KQ_0, stride_K);
#pragma unroll
                for (int j = 0; j < ncols/frag_n; ++j) {
                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                }
            }
#pragma unroll
            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
            }
        }

        __syncthreads();

        // Calculate softmax for each KQ column using the current max. value.
        // The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (std::is_same<KQ_acc_t, float>::value) {
                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];

                    if (use_logit_softcap) {
                        KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
                    }
                }

                float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*stride_mask + k_VKQ_0 + k]) : 0.0f;
                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
                }
                KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);

                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
                KQ_max_scale_f[j0/nwarps] = expf(diff);
                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                    KQ_max_scale_f[j0/nwarps] = 0.0f;
                }
                KQ_max_f[j0/nwarps] = KQ_max_new;

                float KQ_rowsum_add = 0.0f;
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
                    KQ_f_tmp[k0/warp_size] = expf(diff);
                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                        KQ_f_tmp[k0/warp_size] = 0.0f;
                    }
                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
                }
                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);

                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
            } else {
                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];

                    if (use_logit_softcap) {
                        // There is no dedicated tangens hyperbolicus function for half2,
                        // so tanh(x) = (e^(2x) - 1)/(e^(2x) + 1) is evaluated explicitly.
                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
                                               /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));

                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                    }
                }

                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    // Same mask row stride as the FP32 path, expressed in half2 units.
                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*stride_mask + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                }
                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                // Flush-to-zero: clear the scale bits wherever diff <= SOFTMAX_FTZ_THRESHOLD.
                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
                KQ_max_h2[j0/nwarps] = KQ_max_new;

                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                    const int k = k0 + threadIdx.x;

                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
                    KQ2_tmp[k0/warp_size] = h2exp(diff);
                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];
                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
                }
                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);

                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
            }
        }

        __syncthreads();

        // Load the softmax-ed KQ tile as the B operand of the VKQ product.
        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                wmma::load_matrix_sync(
                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                    KQ + j0*(kqar*kqs_padded) + k,
                    kqar*kqs_padded);
            }
        }

        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
            for (int j = 0; j < ncols/frag_n; ++j) {
                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
            }

#pragma unroll
            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                frag_a_V v_a;
                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_V + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_V);
#pragma unroll
                for (int j = 0; j < ncols/frag_n; ++j) {
                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                }
            }
        }

        __syncthreads();

        // Each group of warps with the same threadIdx.y % VKQ_ratio writes its partial VKQ into its own slice of KQ.
        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                wmma::store_matrix_sync(
                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
                    D_padded, wmma::mem_col_major);
            }
        }

        __syncthreads();

        // Rescale the running VKQ accumulator by the max correction and add the partial slices.
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            half2 VKQ_scale;
            if (std::is_same<KQ_acc_t, float>::value) {
                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
            } else {
                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
            }

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                const int i = i0 + threadIdx.x;
                if (i0 + warp_size > D/2 && i >= D/2) {
                    break;
                }

                half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
                for (int l = 0; l < VKQ_ratio; ++l) {
                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
                }
                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
            }
        }

        __syncthreads();
    }

    // Apply attention sinks: only the blockIdx.y == 0 block folds sinks[head] into the running max and rowsum.
    if (sinksf && blockIdx.y == 0) {
        const float sinkf = sinksf[head];
        const half  sinkh = __float2half(sinkf);

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (std::is_same<KQ_acc_t, float>::value) {
                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);

                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
                KQ_max_f[j0/nwarps] = kqmax_new;

                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);

                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                    const int i = i0 + threadIdx.x;
                    if (i0 + warp_size > D/2 && i >= D/2) break;
                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
                }
            } else {
                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
                half kqmax_new = fmaxf(kqmax_old, sinkh);
                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);

                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);

                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
                const half val = hexp(sinkh - kqmax_new);
                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);

#pragma unroll
                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                    const int i = i0 + threadIdx.x;
                    if (i0 + warp_size > D/2 && i >= D/2) break;
                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
                }
            }
        }

        __syncthreads();
    }

    // Write back. Rows past ne01 are skipped; no barrier follows, so returning early is safe.
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j_VKQ = j0 + threadIdx.y;
        if (ic0 + j_VKQ >= ne01) {
            return;
        }

        float KQ_rowsum_j;
        if (std::is_same<KQ_acc_t, float>::value) {
            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
        } else {
            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
        }

        // dst row index: ((sequence, Q column, head) flattened) * gridDim.y + blockIdx.y.
        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < D; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
            if (i0 + warp_size > D && i >= D) {
                break;
            }
            float dst_val = VKQ[j_VKQ*D_padded + i];
            if (gridDim.y == 1) {
                dst_val /= KQ_rowsum_j;
            }
            dst[j_dst_unrolled*D + i] = dst_val;
        }

        if (gridDim.y == 1 || threadIdx.x != 0) {
            continue;
        }

        float2 dst_meta_val;
        if (std::is_same<KQ_acc_t, float>::value) {
            dst_meta_val.x = KQ_max_f[j0/nwarps];
        } else {
            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
        }
        dst_meta_val.y = KQ_rowsum_j;
        dst_meta[j_dst_unrolled] = dst_meta_val;
    }
#else
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
        nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb11, nb12, nb13,
        nb21, nb22, nb23,
        ne31, ne32, ne33,
        nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
}
| 486 | |
// Largest power of two that evenly divides x (e.g. 12 -> 4, 7 -> 1).
// x must be non-zero: 0 is divisible by every power of two, so the recursion never terminates.
constexpr int get_max_power_of_2(int x) {
    return (x % 2 != 0) ? 1 : 2*get_max_power_of_2(x/2);
}

static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
| 495 | |
// Number of VKQ rows calculated in parallel.
// D/frag_m row tiles exist. The number processed at once is the largest power of two that divides
// that count, capped at nwarps. The result is that count times frag_m rows per tile.
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
    return frag_m*(nwarps < get_max_power_of_2(D/frag_m) ? nwarps : get_max_power_of_2(D/frag_m));
}

static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
| 510 | |
// Launches one WMMA FlashAttention instantiation with head size D, cols_per_block Q columns per
// block and KQ accumulator type KQ_acc_t, using 4 warps per block.
// The logit-softcap variant is picked at runtime from the float stored in dst->op_params[2]:
// a value equal to 0.0f selects the kernel without softcapping.
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr int nwarps     = 4;
    // 32-row fragments are only used with 8 columns per block and head sizes divisible by 32.
    constexpr int frag_m     = (cols_per_block == 8 && D % 32 == 0) ? 32 : 16;
    constexpr int VKQ_stride = get_VKQ_stride(D, nwarps, frag_m);

    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

    // Default to the plain kernel and switch to the softcap variant only for a non-zero softcap.
    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, VKQ_stride, KQ_acc_t, /*use_logit_softcap =*/ false>;
    if (logit_softcap != 0.0f) {
        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, VKQ_stride, KQ_acc_t, /*use_logit_softcap =*/ true>;
    }

    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
}
| 535 | |
// Host entry point: picks the WMMA FlashAttention instantiation for dst from its precision,
// the head size Q->ne[0] and the number of Q columns Q->ne[1].
//
//   Non-default precision: float KQ accumulators. Uses 16 columns per block when there are <= 32 Q
//   columns or the head size is > 128, otherwise 32 columns per block.
//   Default precision: half KQ accumulators. On non-HIP builds, uses 8 columns per block when
//   there are <= 8 Q columns and the head size is a multiple of the warp size. Otherwise uses 16
//   columns for <= 32 Q columns, else 32 columns.
//
// Head sizes without an instantiation abort via GGML_ABORT.
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    const enum ggml_prec prec      = ggml_flash_attn_ext_get_prec(dst);
    const int            warp_size = ggml_cuda_info().devices[ctx.device].warp_size;

    const int64_t head_size = Q->ne[0];
    const int64_t n_q_cols  = Q->ne[1];

    if (prec != GGML_PREC_DEFAULT) {
        if (n_q_cols <= 32 || head_size > 128) {
            constexpr int cols_per_block = 16;
            switch (head_size) {
                case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); break;
                case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); break;
                case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); break;
                case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); break;
                case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); break;
                case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); break;
                default:  GGML_ABORT("fatal error");
            }
        } else {
            // Only head sizes <= 128 reach this branch. D == 256 is deliberately not instantiated with
            // 32 columns per block and float accumulators (presumably a shared-memory limit; TODO confirm).
            constexpr int cols_per_block = 32;
            switch (head_size) {
                case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); break;
                case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); break;
                case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); break;
                case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); break;
                case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); break;
                default:  GGML_ABORT("fatal error");
            }
        }
        return;
    }

#if !defined(GGML_USE_HIP)
    // 8 columns per block uses 32-row fragments, so only head sizes divisible by the warp size qualify.
    if (n_q_cols <= 8 && head_size % warp_size == 0) {
        constexpr int cols_per_block = 8;
        switch (head_size) {
            case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); break;
            case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); break;
            case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); break;
            case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break;
            default:  GGML_ABORT("fatal error");
        }
        return;
    }
#endif // !defined(GGML_USE_HIP)

    if (n_q_cols <= 32) {
        constexpr int cols_per_block = 16;
        switch (head_size) {
            case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); break;
            case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); break;
            case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); break;
            case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); break;
            case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); break;
            case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break;
            default:  GGML_ABORT("fatal error");
        }
    } else {
        constexpr int cols_per_block = 32;
        switch (head_size) {
            case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); break;
            case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); break;
            case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); break;
            case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); break;
            case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); break;
            case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break;
            default:  GGML_ABORT("fatal error");
        }
    }
}
| 675 | |