| 1 | #include "common.cuh" |
| 2 | #include "cp-async.cuh" |
| 3 | #include "mma.cuh" |
| 4 | #include "fattn-common.cuh" |
| 5 | |
| 6 | using namespace ggml_cuda_mma; |
| 7 | |
| 8 | typedef tile<16, 8, half2> tile_A; |
| 9 | typedef tile< 8, 8, half2> tile_B; |
| 10 | typedef tile<16, 8, half2> tile_B_16; |
| 11 | typedef tile<16, 8, float> tile_C_KQ; |
| 12 | typedef tile<16, 16, float> tile_C_KQ_16; |
| 13 | typedef tile<16, 4, half2> tile_C_VKQ; |
| 14 | typedef tile<16, 8, half2> tile_C_VKQ_16; |
| 15 | |
| 16 | // Config options for specific head sizes. |
| 17 | // Should not affect results, only speed/register pressure/shared memory use. |
| 18 | // |
| 19 | // nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. |
| 20 | // nwarps_max: maximum number of warps per CUDA block; up to 8 warps in total can run per SM (given enough shared memory). |
| 21 | // Q_in_reg: whether the Q values should be kept permanently in registers. |
| 22 | // nstages_target: targeted number of pipeline stages for cp_async (if available); 0 means synchronous data loading. |
| 23 | // nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. |
| 24 | // nbatch_V2: number of V half2 values in direction of DV to load in parallel. |
| 25 | // nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. |
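|  | // |
|  | // Worked example (numbers read off the specializations below, purely illustrative): for DKQ == DV == 128, |
|  | // nbatch_fa == 64 means the softmax state is rescaled once per 64 KV rows, and nbatch_K2 == nbatch_V2 == 64 |
|  | // half2 values == DKQ/2 means a whole K/V row is staged in shared memory per load; the 576/512 (MLA) config |
|  | // instead splits rows into smaller slabs in most cases. |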
| 26 | |
| 27 | template <int DKQ, int DV> |
| 28 | struct fattn_mma_f16_config; |
| 29 | |
| 30 | template <> |
| 31 | struct fattn_mma_f16_config< 64, 64> { |
| 32 | static constexpr int nbatch_fa = 64; |
| 33 | static constexpr int nwarps_max = 4; |
| 34 | static constexpr bool Q_in_reg = true; |
| 35 | static constexpr int nstages_target = 2; |
| 36 | |
| 37 | static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 38 | return 32; |
| 39 | } |
| 40 | |
| 41 | static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 42 | return 32; |
| 43 | } |
| 44 | |
| 45 | static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 46 | return 32; |
| 47 | } |
| 48 | |
| 49 | static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 50 | return 32; |
| 51 | } |
| 52 | |
| 53 | static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 54 | return 32; |
| 55 | } |
| 56 | |
| 57 | static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 58 | return 32; |
| 59 | } |
| 60 | }; |
| 61 | |
| 62 | template <> |
| 63 | struct fattn_mma_f16_config< 80, 80> { |
| 64 | static constexpr int nbatch_fa = 64; |
| 65 | static constexpr int nwarps_max = 4; |
| 66 | static constexpr bool Q_in_reg = true; |
| 67 | static constexpr int nstages_target = 2; |
| 68 | |
| 69 | static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 70 | return 40; |
| 71 | } |
| 72 | |
| 73 | static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 74 | return 40; |
| 75 | } |
| 76 | |
| 77 | static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 78 | return 40; |
| 79 | } |
| 80 | |
| 81 | static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 82 | return 40; |
| 83 | } |
| 84 | |
| 85 | static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 86 | return 40; |
| 87 | } |
| 88 | |
| 89 | static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 90 | return 40; |
| 91 | } |
| 92 | }; |
| 93 | |
| 94 | template <> |
| 95 | struct fattn_mma_f16_config< 96, 96> { |
| 96 | static constexpr int nbatch_fa = 64; |
| 97 | static constexpr int nwarps_max = 4; |
| 98 | static constexpr bool Q_in_reg = true; |
| 99 | static constexpr int nstages_target = 2; |
| 100 | |
| 101 | static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 102 | return 48; |
| 103 | } |
| 104 | |
| 105 | static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 106 | return 48; |
| 107 | } |
| 108 | |
| 109 | static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 110 | return 48; |
| 111 | } |
| 112 | |
| 113 | static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 114 | return 48; |
| 115 | } |
| 116 | |
| 117 | static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 118 | return 48; |
| 119 | } |
| 120 | |
| 121 | static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 122 | return 48; |
| 123 | } |
| 124 | }; |
| 125 | |
| 126 | template <> |
| 127 | struct fattn_mma_f16_config<112, 112> { |
| 128 | static constexpr int nbatch_fa = 64; |
| 129 | static constexpr int nwarps_max = 4; |
| 130 | static constexpr bool Q_in_reg = true; |
| 131 | static constexpr int nstages_target = 2; |
| 132 | |
| 133 | static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 134 | return 56; |
| 135 | } |
| 136 | |
| 137 | static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 138 | return 56; |
| 139 | } |
| 140 | |
| 141 | static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 142 | return 56; |
| 143 | } |
| 144 | |
| 145 | static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 146 | return 56; |
| 147 | } |
| 148 | |
| 149 | static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 150 | return 56; |
| 151 | } |
| 152 | |
| 153 | static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 154 | return 56; |
| 155 | } |
| 156 | }; |
| 157 | |
| 158 | template <> |
| 159 | struct fattn_mma_f16_config<128, 128> { |
| 160 | static constexpr int nbatch_fa = 64; |
| 161 | static constexpr int nwarps_max = 4; |
| 162 | static constexpr bool Q_in_reg = true; |
| 163 | static constexpr int nstages_target = 2; |
| 164 | |
| 165 | static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 166 | return 64; |
| 167 | } |
| 168 | |
| 169 | static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 170 | return 64; |
| 171 | } |
| 172 | |
| 173 | static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 174 | return 64; |
| 175 | } |
| 176 | |
| 177 | static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 178 | return 64; |
| 179 | } |
| 180 | |
| 181 | static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 182 | return 64; |
| 183 | } |
| 184 | |
| 185 | static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 186 | return 64; |
| 187 | } |
| 188 | }; |
| 189 | |
| 190 | template <> |
| 191 | struct fattn_mma_f16_config<256, 256> { |
| 192 | static constexpr int nbatch_fa = 32; |
| 193 | static constexpr int nwarps_max = 4; |
| 194 | static constexpr bool Q_in_reg = true; |
| 195 | static constexpr int nstages_target = 2; |
| 196 | |
| 197 | static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 198 | return 128; |
| 199 | } |
| 200 | |
| 201 | static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 202 | return 128; |
| 203 | } |
| 204 | |
| 205 | static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 206 | return 128; |
| 207 | } |
| 208 | |
| 209 | static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 210 | return 128; |
| 211 | } |
| 212 | |
| 213 | static int get_nbatch_combine_host(const int cc, const int ncols) { |
| 214 | if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { |
| 215 | return ncols <= 16 ? 128 : 64; |
| 216 | } |
| 217 | return 64; |
| 218 | } |
| 219 | |
| 220 | static constexpr __device__ int get_nbatch_combine_device(int ncols) { |
| 221 | #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 222 | return ncols <= 16 ? 128 : 64; |
| 223 | #else |
| 224 | GGML_UNUSED(ncols); |
| 225 | return 128; |
| 226 | #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 227 | } |
| 228 | }; |
| 229 | |
| 230 | template <> |
| 231 | struct fattn_mma_f16_config<576, 512> { |
| 232 | static constexpr int nbatch_fa = 32; |
| 233 | static constexpr int nwarps_max = 8; |
| 234 | static constexpr bool Q_in_reg = false; |
| 235 | static constexpr int nstages_target = 1; |
| 236 | |
| 237 | static int get_nbatch_K2_host(const int cc, const int ncols) { |
| 238 | if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { |
| 239 | return ncols <= 16 ? 96 : 160; |
| 240 | } |
| 241 | return ncols <= 16 ? 288 : 160; |
| 242 | } |
| 243 | |
| 244 | static constexpr __device__ int get_nbatch_K2_device(int ncols) { |
| 245 | #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 246 | return ncols <= 16 ? 96 : 160; |
| 247 | #else |
| 248 | return ncols <= 16 ? 288 : 160; |
| 249 | #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 250 | } |
| 251 | |
| 252 | static int get_nbatch_V2_host(const int cc, const int ncols) { |
| 253 | if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { |
| 254 | return ncols <= 16 ? 64 : 128; |
| 255 | } |
| 256 | return ncols <= 16 ? 256 : 128; |
| 257 | } |
| 258 | |
| 259 | static constexpr __device__ int get_nbatch_V2_device(int ncols) { |
| 260 | #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 261 | return ncols <= 16 ? 64 : 128; |
| 262 | #else |
| 263 | return ncols <= 16 ? 256 : 128; |
| 264 | #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 265 | } |
| 266 | |
| 267 | static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 268 | return 128; |
| 269 | } |
| 270 | |
| 271 | static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 272 | return 128; |
| 273 | } |
| 274 | }; |
| 275 | |
| 276 | // ------------------------------------------------------------------------------------------------------------------ |
| 277 | |
| 278 | template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async> |
| 279 | static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( |
| 280 | const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { |
| 281 | |
| 282 | // K/V data is loaded with decreasing granularity for D for better memory bandwidth. |
| 283 | // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. |
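|  | // As a rough sketch of what the lambdas below do: with cp.async each chunk is 16 bytes == 4 half2 values, and |
|  | // ggml_cuda_unroll instantiates the loader for stride_k == 32, 16, 8, 4, 2 (cp.async) or 32, 16, 8 (synchronous), |
|  | // so most of each row is copied with the widest stride and the remainder with progressively narrower ones. |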
| 284 | |
| 285 | if (use_cp_async) { |
| 286 | constexpr int preload = 64; |
| 287 | constexpr int h2_per_chunk = 16/sizeof(half2); |
| 288 | const int chunks_per_row = D2 / h2_per_chunk; |
| 289 | |
| 290 | const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); |
| 291 | |
| 292 | auto load = [&] __device__ (auto n) { |
| 293 | const int stride_k = WARP_SIZE >> n; |
| 294 | const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); |
| 295 | const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); |
| 296 | const int stride_i = WARP_SIZE / stride_k; |
| 297 | |
| 298 | if (k0_start == k0_stop) { |
| 299 | return; |
| 300 | } |
| 301 | |
| 302 | #pragma unroll |
| 303 | for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { |
| 304 | const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); |
| 305 | |
| 306 | if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { |
| 307 | break; |
| 308 | } |
| 309 | |
| 310 | #pragma unroll |
| 311 | for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
| 312 | const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
| 313 | |
| 314 | cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); |
| 315 | } |
| 316 | } |
| 317 | }; |
| 318 | ggml_cuda_unroll<5>{}(load); |
| 319 | } else { |
| 320 | static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds" ); |
| 321 | auto load = [&] __device__ (const int n) { |
| 322 | const int stride_k = WARP_SIZE >> n; |
| 323 | const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); |
| 324 | const int k0_stop = D2 - D2 % (1*stride_k); |
| 325 | const int stride_i = WARP_SIZE / stride_k; |
| 326 | |
| 327 | if (k0_start == k0_stop) { |
| 328 | return; |
| 329 | } |
| 330 | |
| 331 | #pragma unroll |
| 332 | for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { |
| 333 | const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); |
| 334 | |
| 335 | if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { |
| 336 | break; |
| 337 | } |
| 338 | |
| 339 | #pragma unroll |
| 340 | for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
| 341 | const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
| 342 | |
| 343 | tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; |
| 344 | } |
| 345 | } |
| 346 | }; |
| 347 | ggml_cuda_unroll<3>{}(load); |
| 348 | } |
| 349 | } |
| 350 | |
| 351 | template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async> |
| 352 | static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( |
| 353 | const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { |
| 354 | static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter" ); |
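|  | // Layout note for the stores below: each of the ncols1 mask rows occupies nbatch_fa half values plus 16 bytes of |
|  | // padding (the "+ 16" / "+ 4" terms), presumably so that consecutive rows do not map to the same shared memory banks. |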
| 355 | |
| 356 | if (use_cp_async) { |
| 357 | constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; |
| 358 | constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; |
| 359 | constexpr int stride_j = nwarps * cols_per_warp; |
| 360 | |
| 361 | const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); |
| 362 | |
| 363 | #pragma unroll |
| 364 | for (int j0 = 0; j0 < ncols1; j0 += stride_j) { |
| 365 | const int j = j0 + threadIdx.y*cols_per_warp + |
| 366 | (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); |
| 367 | |
| 368 | if (j0 + stride_j > ncols1 && j >= ncols1) { |
| 369 | break; |
| 370 | } |
| 371 | |
| 372 | const int i = 4 * (threadIdx.x % (nbatch_fa/8)); |
| 373 | |
| 374 | cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); |
| 375 | } |
| 376 | return; |
| 377 | } |
| 378 | |
| 379 | constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; |
| 380 | constexpr int stride_j = nwarps * cols_per_warp; |
| 381 | #pragma unroll |
| 382 | for (int j0 = 0; j0 < ncols1; j0 += stride_j) { |
| 383 | const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); |
| 384 | |
| 385 | if (j0 + stride_j > ncols1 && j >= ncols1) { |
| 386 | break; |
| 387 | } |
| 388 | |
| 389 | const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); |
| 390 | |
| 391 | tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; |
| 392 | } |
| 393 | } |
| 394 | |
| 395 | template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, |
| 396 | bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter> |
| 397 | static __device__ __forceinline__ void flash_attn_ext_f16_iter( |
| 398 | const float2 * const __restrict__ Q_f2, |
| 399 | const half2 * const __restrict__ K_h2, |
| 400 | const half2 * const __restrict__ V_h2, |
| 401 | const half2 * const __restrict__ mask_h2, |
| 402 | float2 * const __restrict__ dstk, |
| 403 | float2 * const __restrict__ dstk_fixup, |
| 404 | const float scale, |
| 405 | const float slope, |
| 406 | const float logit_softcap, |
| 407 | const int ne01, |
| 408 | const int ne02, |
| 409 | const int stride_K, |
| 410 | const int stride_V, |
| 411 | const int stride_mask, |
| 412 | half2 * const __restrict__ tile_Q, |
| 413 | half2 * const __restrict__ tile_K, |
| 414 | half2 * const __restrict__ tile_V, |
| 415 | half2 * const __restrict__ tile_mask, |
| 416 | const tile_B * const __restrict__ Q_B, |
| 417 | tile_C_VKQ * const __restrict__ VKQ_C, |
| 418 | float * const __restrict__ KQ_max, |
| 419 | float * const __restrict__ KQ_rowsum, |
| 420 | const int kb0) { |
| 421 | #ifdef TURING_MMA_AVAILABLE |
| 422 | typedef fattn_mma_f16_config<DKQ, DV> c; |
| 423 | |
| 424 | #ifdef CP_ASYNC_AVAILABLE |
| 425 | constexpr int nstages = c::nstages_target; |
| 426 | #else |
| 427 | constexpr int nstages = 0; |
| 428 | #endif // CP_ASYNC_AVAILABLE |
| 429 | |
| 430 | constexpr int cols_per_warp = ntiles * tile_B::I; |
| 431 | constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; |
| 432 | constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. |
| 433 | constexpr int ncols = ncols1 * ncols2; |
| 434 | constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); |
| 435 | constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); |
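|  | // Illustrative numbers only (one of several valid combinations): with ntiles == 1 and tile_B::I == 8, |
|  | // cols_per_warp == 8; with nwarps == 4, ncols1 == 8, ncols2 == 1 this gives np == 4, i.e. 4 warps cooperate on each Q column. |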
| 436 | |
| 437 | constexpr int stride_tile_Q = DKQ/2 + 4; |
| 438 | constexpr int stride_tile_K = nbatch_K2 + 4; |
| 439 | |
| 440 | static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA" ); |
| 441 | constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; |
| 442 | |
| 443 | const int k_VKQ_0 = kb0 * c::nbatch_fa; |
| 444 | tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; |
| 445 | |
| 446 | // Use wide variants of tiles if ntiles >= 2. |
| 447 | tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; |
| 448 | tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; |
| 449 | tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; |
| 450 | |
| 451 | if constexpr (nstages > 1) { |
| 452 | static_assert(!mla, "multi-stage loading not implemented for MLA" ); |
| 453 | static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading" ); |
| 454 | constexpr bool use_cp_async = true; |
| 455 | cp_async_wait_all(); |
| 456 | __syncthreads(); |
| 457 | flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async> |
| 458 | (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); |
| 459 | } else { |
| 460 | constexpr bool use_cp_async = nstages == 1; |
| 461 | if (ncols2 > 1 || mask_h2) { |
| 462 | flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); |
| 463 | } |
| 464 | } |
| 465 | |
| 466 | #pragma unroll |
| 467 | for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { |
| 468 | const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; |
| 469 | const int k0_diff = k0_stop - k0_start; |
| 470 | |
| 471 | if (nstages <= 1) { |
| 472 | constexpr bool use_cp_async = nstages == 1; |
| 473 | flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> |
| 474 | (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); |
| 475 | if (use_cp_async) { |
| 476 | cp_async_wait_all(); |
| 477 | } |
| 478 | __syncthreads(); |
| 479 | } |
| 480 | |
| 481 | // Calculate tile of KQ: |
| 482 | if constexpr (c::Q_in_reg) { |
| 483 | #pragma unroll |
| 484 | for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { |
| 485 | const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; |
| 486 | #pragma unroll |
| 487 | for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { |
| 488 | tile_A K_A; |
| 489 | load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); |
| 490 | if (ntiles == 1) { |
| 491 | mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); |
| 492 | } else { |
| 493 | #pragma unroll |
| 494 | for (int t = 0; t < ntiles/2; ++t) { |
| 495 | // Wide version of KQ_C is column-major => swap A and B. |
| 496 | mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); |
| 497 | } |
| 498 | } |
| 499 | } |
| 500 | } |
| 501 | } else { |
| 502 | static_assert(ntiles == 2, "ntiles != 2 not implemented" ); |
| 503 | #pragma unroll |
| 504 | for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { |
| 505 | load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); |
| 506 | |
| 507 | #pragma unroll |
| 508 | for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { |
| 509 | const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; |
| 510 | |
| 511 | tile_A K_A; |
| 512 | load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); |
| 513 | |
| 514 | // Wide version of KQ_C is column-major => swap A and B. |
| 515 | mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); |
| 516 | } |
| 517 | } |
| 518 | } |
| 519 | |
| 520 | if (nstages <= 1) { |
| 521 | __syncthreads(); // Only needed if tile_K == tile_V. |
| 522 | } |
| 523 | } |
| 524 | |
| 525 | if (use_logit_softcap) { |
| 526 | static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size" ); |
| 527 | #pragma unroll |
| 528 | for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { |
| 529 | #pragma unroll |
| 530 | for (int l = 0; l < tile_C_KQ::ne; ++l) { |
| 531 | KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); |
| 532 | } |
| 533 | } |
| 534 | } |
| 535 | |
| 536 | float KQ_max_new[cols_per_thread]; |
| 537 | #pragma unroll |
| 538 | for (int col = 0; col < cols_per_thread; ++col) { |
| 539 | KQ_max_new[col] = KQ_max[col]; |
| 540 | } |
| 541 | float KQ_rowsum_add[cols_per_thread] = {0.0f}; |
| 542 | |
| 543 | if (ntiles == 1) { |
| 544 | if (ncols2 > 1 || mask_h2) { |
| 545 | #pragma unroll |
| 546 | for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { |
| 547 | const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; |
| 548 | #pragma unroll |
| 549 | for (int l = 0; l < tile_C_KQ::ne; ++l) { |
| 550 | const int i = i0 + tile_C_KQ::get_i(l); |
| 551 | const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; |
| 552 | |
| 553 | KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * |
| 554 | __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); |
| 555 | } |
| 556 | } |
| 557 | } |
| 558 | |
| 559 | // Calculate softmax for each KQ column using the current max. value. |
| 560 | // The divisor is stored in KQ_rowsum and will be applied at the end. |
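|  | // In formula form: for each column the code tracks m = max_k(KQ_k) and accumulates exp(KQ_k - m); whenever m grows, |
|  | // the previously accumulated rowsum and VKQ values are rescaled by exp(m_old - m_new) (KQ_max_scale further down), |
|  | // i.e. the usual streaming/online softmax formulation. |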
| 561 | static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size" ); |
| 562 | #pragma unroll |
| 563 | for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { |
| 564 | #pragma unroll |
| 565 | for (int l = 0; l < tile_C_KQ::ne; ++l) { |
| 566 | KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); |
| 567 | } |
| 568 | } |
| 569 | |
| 570 | // Values per KQ column are spread across 8 threads, does not need full warp reduce: |
| 571 | #pragma unroll |
| 572 | for (int col = 0; col < cols_per_thread; ++col) { |
| 573 | #pragma unroll |
| 574 | for (int offset = 16; offset >= 4; offset >>= 1) { |
| 575 | KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); |
| 576 | } |
| 577 | } |
| 578 | |
| 579 | static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size" ); |
| 580 | #pragma unroll |
| 581 | for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { |
| 582 | #pragma unroll |
| 583 | for (int l = 0; l < tile_C_KQ::ne; ++l) { |
| 584 | KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); |
| 585 | |
| 586 | KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; |
| 587 | } |
| 588 | } |
| 589 | } else { // ntiles > 1 |
| 590 | if (ncols2 > 1 || mask_h2) { |
| 591 | #pragma unroll |
| 592 | for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { |
| 593 | const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; |
| 594 | #pragma unroll |
| 595 | for (int t = 0; t < ntiles/2; ++t) { |
| 596 | #pragma unroll |
| 597 | for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { |
| 598 | const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; |
| 599 | const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; |
| 600 | |
| 601 | const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); |
| 602 | const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; |
| 603 | KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; |
| 604 | KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; |
| 605 | } |
| 606 | } |
| 607 | } |
| 608 | } |
| 609 | |
| 610 | // Calculate softmax for each KQ column using the current max. value. |
| 611 | // The divisor is stored in KQ_rowsum and will be applied at the end. |
| 612 | static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size" ); |
| 613 | #pragma unroll |
| 614 | for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { |
| 615 | #pragma unroll |
| 616 | for (int t = 0; t < ntiles/2; ++t) { |
| 617 | #pragma unroll |
| 618 | for (int l = 0; l < tile_C_KQ_16::ne; ++l) { |
| 619 | const int KQ_index = 2*t + (l/2) % 2; |
| 620 | KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); |
| 621 | } |
| 622 | } |
| 623 | } |
| 624 | |
| 625 | // Values per KQ column are spread across 4 threads, does not need full warp reduce: |
| 626 | #pragma unroll |
| 627 | for (int col = 0; col < cols_per_thread; ++col) { |
| 628 | #pragma unroll |
| 629 | for (int offset = 2; offset >= 1; offset >>= 1) { |
| 630 | KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); |
| 631 | } |
| 632 | } |
| 633 | |
| 634 | static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size" ); |
| 635 | #pragma unroll |
| 636 | for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { |
| 637 | #pragma unroll |
| 638 | for (int t = 0; t < ntiles/2; ++t) { |
| 639 | #pragma unroll |
| 640 | for (int l = 0; l < tile_C_KQ_16::ne; ++l) { |
| 641 | const int KQ_index = 2*t + (l/2) % 2; |
| 642 | |
| 643 | KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); |
| 644 | |
| 645 | KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; |
| 646 | } |
| 647 | } |
| 648 | } |
| 649 | } |
| 650 | |
| 651 | { |
| 652 | float KQ_max_scale[cols_per_thread]; |
| 653 | #pragma unroll |
| 654 | for (int col = 0; col < cols_per_thread; ++col) { |
| 655 | const float KQ_max_diff = KQ_max[col] - KQ_max_new[col]; |
| 656 | KQ_max_scale[col] = expf(KQ_max_diff); |
| 657 | KQ_max[col] = KQ_max_new[col]; |
| 658 | |
| 659 | *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; |
| 660 | |
| 661 | // Scale previous KQ_rowsum to account for a potential increase in KQ_max: |
| 662 | KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; |
| 663 | } |
| 664 | |
| 665 | if (ntiles == 1) { |
| 666 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); |
| 667 | #pragma unroll |
| 668 | for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { |
| 669 | #pragma unroll |
| 670 | for (int l = 0; l < tile_C_VKQ::ne; ++l) { |
| 671 | VKQ_C[i].x[l] *= KQ_max_scale_h2; |
| 672 | } |
| 673 | } |
| 674 | } else { |
| 675 | #pragma unroll |
| 676 | for (int col = 0; col < cols_per_thread; ++col) { |
| 677 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); |
| 678 | #pragma unroll |
| 679 | for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { |
| 680 | #pragma unroll |
| 681 | for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { |
| 682 | VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; |
| 683 | } |
| 684 | } |
| 685 | } |
| 686 | } |
| 687 | } |
| 688 | |
| 689 | // Convert KQ C tiles into B tiles for VKQ calculation: |
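|  | // As the helper names suggest, get_half2 packs the float KQ values into half2 and get_transposed re-arranges the |
|  | // C-layout tile into B layout so that the probabilities can be used directly as an mma operand for V*KQ below. |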
| 690 | tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; |
| 691 | tile_B_16 * B_16 = (tile_B_16 *) B; |
| 692 | static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size" ); |
| 693 | if (ntiles == 1) { |
| 694 | #pragma unroll |
| 695 | for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { |
| 696 | B[k] = get_transposed(get_half2(KQ_C[k])); |
| 697 | } |
| 698 | } else { |
| 699 | for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { |
| 700 | #pragma unroll |
| 701 | for (int t = 0; t < ntiles/2; ++t) { |
| 702 | B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); |
| 703 | } |
| 704 | } |
| 705 | } |
| 706 | |
| 707 | if (nstages > 1) { |
| 708 | // Preload K tile for next iteration: |
| 709 | constexpr bool use_cp_async = true; |
| 710 | cp_async_wait_all(); |
| 711 | __syncthreads(); |
| 712 | if (!last_iter) { |
| 713 | if (ncols2 > 1 || mask_h2) { |
| 714 | flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async> |
| 715 | (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); |
| 716 | } |
| 717 | flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> |
| 718 | (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); |
| 719 | } |
| 720 | } |
| 721 | |
| 722 | |
| 723 | // For MLA K and V have the same data. |
| 724 | // Therefore, iterate over V in reverse and re-use the data if possible. |
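|  | // Sketch of the idea: reusable_cutoff is the first V position (in half units) that is still resident in shared |
|  | // memory from the last K slab that was loaded; V slabs at or past this cutoff skip their own load and read from |
|  | // the K data instead, while anything before it is loaded separately. |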
| 725 | static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented" ); |
| 726 | constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; |
| 727 | #pragma unroll |
| 728 | for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { |
| 729 | const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; |
| 730 | const int i0_diff = i0_stop - i0_start; |
| 731 | |
| 732 | if (nstages <= 1 && i0_start < reusable_cutoff) { |
| 733 | constexpr bool use_cp_async = nstages == 1; |
| 734 | flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async> |
| 735 | (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); |
| 736 | if (use_cp_async) { |
| 737 | cp_async_wait_all(); |
| 738 | } |
| 739 | __syncthreads(); |
| 740 | } |
| 741 | const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; |
| 742 | |
| 743 | // Calculate VKQ tile: |
| 744 | #pragma unroll |
| 745 | for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { |
| 746 | static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size" ); |
| 747 | #pragma unroll |
| 748 | for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { |
| 749 | const int k0 = k00 + (threadIdx.y % np)*tile_A::J; |
| 750 | |
| 751 | tile_A A; |
| 752 | load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); |
| 753 | if (ntiles == 1) { |
| 754 | mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); |
| 755 | } else { |
| 756 | #pragma unroll |
| 757 | for (int t = 0; t < ntiles/2; ++t) { |
| 758 | // Wide version of VKQ_C is column-major => swap A and B. |
| 759 | mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); |
| 760 | } |
| 761 | } |
| 762 | } |
| 763 | } |
| 764 | |
| 765 | if (nstages <= 1) { |
| 766 | __syncthreads(); // Only needed if tile_K == tile_V. |
| 767 | } |
| 768 | } |
| 769 | #else |
| 770 | GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, |
| 771 | scale, slope, logit_softcap, ne01, ne02, |
| 772 | stride_K, stride_V, stride_mask, |
| 773 | tile_Q, tile_K, tile_V, tile_mask, |
| 774 | Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); |
| 775 | NO_DEVICE_CODE; |
| 776 | #endif // TURING_MMA_AVAILABLE |
| 777 | } |
| 778 | |
| 779 | template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup> |
| 780 | static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( |
| 781 | const float2 * const __restrict__ Q_f2, |
| 782 | const half2 * const __restrict__ K_h2, |
| 783 | const half2 * const __restrict__ V_h2, |
| 784 | const half2 * const __restrict__ mask_h2, |
| 785 | const float * const __restrict__ sinks_f, |
| 786 | float2 * const __restrict__ dstk, |
| 787 | float2 * const __restrict__ dstk_fixup, |
| 788 | const float scale, |
| 789 | const float slope, |
| 790 | const float logit_softcap, |
| 791 | const int ne01, |
| 792 | const int ne02, |
| 793 | const int stride_Q1, |
| 794 | const int stride_Q2, |
| 795 | const int stride_K, |
| 796 | const int stride_V, |
| 797 | const int stride_mask, |
| 798 | const int jt, |
| 799 | const int kb0_start, |
| 800 | const int kb0_stop) { |
| 801 | #ifdef TURING_MMA_AVAILABLE |
| 802 | // In this kernel Q, K, V are matrices while i, j, k are matrix indices. |
| 803 | |
| 804 | typedef fattn_mma_f16_config<DKQ, DV> c; |
| 805 | |
| 806 | #ifdef CP_ASYNC_AVAILABLE |
| 807 | constexpr int nstages = c::nstages_target; |
| 808 | #else |
| 809 | constexpr int nstages = 0; |
| 810 | #endif // CP_ASYNC_AVAILABLE |
| 811 | |
| 812 | constexpr int ncols = ncols1 * ncols2; |
| 813 | constexpr int cols_per_warp = ntiles * tile_B::I; |
| 814 | constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; |
| 815 | constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. |
| 816 | constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); |
| 817 | constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); |
| 818 | |
| 819 | static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps" ); |
| 820 | |
| 821 | constexpr int stride_tile_Q = DKQ/2 + 4; |
| 822 | constexpr int stride_tile_K = nbatch_K2 + 4; |
| 823 | |
| 824 | static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA" ); |
| 825 | constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; |
| 826 | constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; |
| 827 | |
| 828 | extern __shared__ half2 tile_Q[]; |
| 829 | half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; |
| 830 | half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; |
| 831 | half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; |
| 832 | |
| 833 | tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; |
| 834 | tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; |
| 835 | |
| 836 | tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; |
| 837 | tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; |
| 838 | |
| 839 | float KQ_rowsum[cols_per_thread] = {0.0f}; |
| 840 | float KQ_max[cols_per_thread]; |
| 841 | #pragma unroll |
| 842 | for (int col = 0; col < cols_per_thread; ++col) { |
| 843 | KQ_max[col] = -FLT_MAX/2.0f; |
| 844 | } |
| 845 | |
| 846 | // Load Q data into tile_Q, either temporarily or permanently. |
| 847 | // Q in registers is faster, but register pressure is the biggest bottleneck. |
| 848 | // The loading is done with decreasing granularity for D for better memory bandwidth. |
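|  | // Note that the softmax scale is folded into Q here (scale_h2 below), so the KQ mma directly produces the scaled |
|  | // logits and no extra multiplication is needed in the inner loop. |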
| 849 | const half2 scale_h2 = make_half2(scale, scale); |
| 850 | #pragma unroll |
| 851 | for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { |
| 852 | const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); |
| 853 | const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); |
| 854 | const int stride_jc = WARP_SIZE / stride_k; |
| 855 | |
| 856 | if (k0_start == k0_stop) { |
| 857 | continue; |
| 858 | } |
| 859 | |
| 860 | #pragma unroll |
| 861 | for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { |
| 862 | const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); |
| 863 | |
| 864 | if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { |
| 865 | break; |
| 866 | } |
| 867 | |
| 868 | const int j = jc / ncols2; |
| 869 | const int c = jc % ncols2; |
| 870 | |
| 871 | if (jt*ncols1 + j < ne01) { |
| 872 | #pragma unroll |
| 873 | for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
| 874 | const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
| 875 | |
| 876 | const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; |
| 877 | tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); |
| 878 | } |
| 879 | } else { |
| 880 | #pragma unroll |
| 881 | for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
| 882 | const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
| 883 | |
| 884 | tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); |
| 885 | } |
| 886 | } |
| 887 | } |
| 888 | } |
| 889 | |
| 890 | __syncthreads(); |
| 891 | |
| 892 | if (c::Q_in_reg) { |
| 893 | const int j0 = (threadIdx.y / np) * cols_per_warp; |
| 894 | |
| 895 | #pragma unroll |
| 896 | for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { |
| 897 | if (ntiles == 1) { |
| 898 | load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); |
| 899 | } else { |
| 900 | #pragma unroll |
| 901 | for (int t = 0; t < ntiles/2; ++t) { |
| 902 | load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], |
| 903 | tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); |
| 904 | } |
| 905 | } |
| 906 | } |
| 907 | } |
| 908 | |
| 909 | __syncthreads(); |
| 910 | |
| 911 | // Preload mask and K data for first iteration when using cp_async with multiple stages: |
| 912 | if constexpr (nstages > 1) { |
| 913 | static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline" ); |
| 914 | constexpr bool use_cp_async = true; |
| 915 | if (ncols2 > 1 || mask_h2) { |
| 916 | flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async> |
| 917 | (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); |
| 918 | } |
| 919 | flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> |
| 920 | (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); |
| 921 | } |
| 922 | |
| 923 | // Iterate over ne11 == previous tokens: |
| 924 | int kb0 = kb0_start; |
| 925 | for (; kb0 < kb0_stop-1; ++kb0) { |
| 926 | constexpr bool last_iter = false; |
| 927 | flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter> |
| 928 | (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, |
| 929 | ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); |
| 930 | } |
| 931 | { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. |
| 932 | constexpr bool last_iter = true; |
| 933 | flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter> |
| 934 | (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, |
| 935 | ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); |
| 936 | } |
| 937 | |
| 938 | // With multi-stage loading there is no __syncthreads at the end of the iter, |
| 939 | // there can be a race condition on shared memory access for combining/writing back results. |
| 940 | if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { |
| 941 | __syncthreads(); |
| 942 | } |
| 943 | |
| 944 | // Finally, sum up partial KQ rowsums. |
| 945 | // The partial sums are spread across 8 or 4 threads each (depending on ntiles); a full warp reduce is not needed. |
| 946 | { |
| 947 | constexpr int offset_first = ntiles == 1 ? 16 : 2; |
| 948 | constexpr int offset_last = ntiles == 1 ? 4 : 1; |
| 949 | #pragma unroll |
| 950 | for (int col = 0; col < cols_per_thread; ++col) { |
| 951 | #pragma unroll |
| 952 | for (int offset = offset_first; offset >= offset_last; offset >>= 1) { |
| 953 | KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); |
| 954 | } |
| 955 | } |
| 956 | } |
| 957 | |
| 958 | // If attention sinks are used, potentially re-scale if KQ_max is small. |
| 959 | // Also add the sink as a value to KQ_rowsum; this is done after the synchronization of KQ_rowsum |
| 960 | // so it is done unconditionally for every thread. |
| 961 | if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { |
| 962 | float KQ_max_scale[cols_per_thread]; |
| 963 | #pragma unroll |
| 964 | for (int col = 0; col < cols_per_thread; ++col) { |
| 965 | static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented" ); |
| 966 | const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); |
| 967 | const float sink = sinks_f[jc % ncols2]; |
| 968 | |
| 969 | const float KQ_max_new = fmaxf(KQ_max[col], sink); |
| 970 | const float KQ_max_diff = KQ_max[col] - KQ_max_new; |
| 971 | KQ_max_scale[col] = expf(KQ_max_diff); |
| 972 | KQ_max[col] = KQ_max_new; |
| 973 | |
| 974 | *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; |
| 975 | |
| 976 | const float KQ_max_add = expf(sink - KQ_max_new); |
| 977 | KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; |
| 978 | } |
| 979 | |
| 980 | if (ntiles == 1) { |
| 981 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); |
| 982 | #pragma unroll |
| 983 | for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { |
| 984 | #pragma unroll |
| 985 | for (int l = 0; l < tile_C_VKQ::ne; ++l) { |
| 986 | VKQ_C[i].x[l] *= KQ_max_scale_h2; |
| 987 | } |
| 988 | } |
| 989 | } else { |
| 990 | #pragma unroll |
| 991 | for (int col = 0; col < cols_per_thread; ++col) { |
| 992 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); |
| 993 | #pragma unroll |
| 994 | for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { |
| 995 | #pragma unroll |
| 996 | for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { |
| 997 | VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; |
| 998 | } |
| 999 | } |
| 1000 | } |
| 1001 | } |
| 1002 | } |
| 1003 | |
| 1004 | // Combine VKQ accumulator values if np > 1. |
| 1005 | // It's also faster to do small writes to shared memory followed by one large write to VRAM than to do many small writes to VRAM. |
| 1006 | // So also write VKQ accumulators to shared memory in column-major format if np == 1. |
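|  | // The 16 bytes of per-row padding in tile_Q double as storage for the per-column meta data (KQ max and KQ rowsum): |
|  | // it is written below as a float2 at offset nbatch_combine/2 into the float2 view of each row, i.e. right after the |
|  | // nbatch_combine half2 of accumulator data. |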
| 1007 | |
| 1008 | constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); |
| 1009 | constexpr int tile_stride = nbatch_combine + 4; |
| 1010 | static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine" ); |
| 1011 | |
| 1012 | if constexpr (ntiles == 1) { |
| 1013 | const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset |
| 1014 | const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta |
| 1015 | const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum |
| 1016 | |
| 1017 | if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { |
| 1018 | // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. |
| 1019 | ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; |
| 1020 | } |
| 1021 | |
| 1022 | __syncthreads(); |
| 1023 | |
| 1024 | if (np == 1) { |
| 1025 | // No combination is needed, the meta data can be directly written from registers to VRAM. |
| 1026 | if (needs_fixup && threadIdx.x < tile_B::I) { |
| 1027 | float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; |
| 1028 | dstk_fixup_meta[jc_cwm] = KQ_cmr; |
| 1029 | } |
| 1030 | if (is_fixup && threadIdx.x < tile_B::I) { |
| 1031 | float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; |
| 1032 | dstk_fixup_meta[jc_cwm] = KQ_cmr; |
| 1033 | } |
| 1034 | } |
| 1035 | } else { |
| 1036 | static_assert(ntiles == 2 || ntiles == 4, "bad ntiles" ); |
| 1037 | const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta |
| 1038 | + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) |
| 1039 | + tile_C_VKQ_16::get_i(threadIdx.x % 4); |
| 1040 | const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum |
| 1041 | |
| 1042 | if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { |
| 1043 | // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. |
| 1044 | ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; |
| 1045 | } |
| 1046 | |
| 1047 | __syncthreads(); |
| 1048 | |
| 1049 | if (np == 1) { |
| 1050 | // No combination is needed, the meta data can be directly written from registers to VRAM. |
| 1051 | if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { |
| 1052 | float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; |
| 1053 | dstk_fixup_meta[jc_cwm] = KQ_cmr; |
| 1054 | } |
| 1055 | if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { |
| 1056 | float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; |
| 1057 | dstk_fixup_meta[jc_cwm] = KQ_cmr; |
| 1058 | } |
| 1059 | } |
| 1060 | } |
| 1061 | |
| 1062 | static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles" ); |
| 1063 | if (np > 1 && threadIdx.y % np == 0) { |
| 1064 | // Combine the meta data for parallel warps via shared memory. |
| 1065 | // Warps with threadIdx.y % np != 0 must NOT return early. |
| 1066 | // All threads must return simultaneously to avoid race conditions with work on the next tile. |
| 1067 | |
| 1068 | constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; |
| 1069 | |
| 1070 | const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); |
| 1071 | float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; |
| 1072 | float2 meta[nmeta]; |
| 1073 | #pragma unroll |
| 1074 | for (int imeta = 0; imeta < nmeta; ++imeta) { |
| 1075 | meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; |
| 1076 | } |
| 1077 | |
| 1078 | float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. |
| 1079 | #pragma unroll |
| 1080 | for (int imeta = 1; imeta < nmeta; ++imeta) { |
| 1081 | KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); |
| 1082 | } |
| 1083 | #pragma unroll |
| 1084 | for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { |
| 1085 | if (offset < WARP_SIZE) { |
| 1086 | KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); |
| 1087 | } |
| 1088 | } |
| 1089 | |
| 1090 | float KQ_cms[nmeta]; // KQ combine max scale per warp. |
| 1091 | #pragma unroll |
| 1092 | for (int imeta = 0; imeta < nmeta; ++imeta) { |
| 1093 | KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); |
| 1094 | } |
| 1095 | |
| 1096 | float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. |
| 1097 | #pragma unroll |
| 1098 | for (int imeta = 1; imeta < nmeta; ++imeta) { |
| 1099 | KQ_crs += KQ_cms[imeta]*meta[imeta].y; |
| 1100 | } |
| 1101 | #pragma unroll |
| 1102 | for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { |
| 1103 | if (offset < WARP_SIZE) { |
| 1104 | KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); |
| 1105 | } |
| 1106 | } |
| 1107 | |
| 1108 | __syncthreads(); |
| 1109 | |
| 1110 | // Write back combined meta data: |
| 1111 | #pragma unroll |
| 1112 | for (int imeta = 0; imeta < nmeta; ++imeta) { |
| 1113 | if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { |
| 1114 | // Combined KQ max scale + rowsum. |
| 1115 | meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); |
| 1116 | } |
| 1117 | } |
| 1118 | |
| 1119 | // Combined KQ max + rowsum. |
| 1120 | static_assert(cols_per_warp <= WARP_SIZE); |
| 1121 | if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { |
| 1122 | float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; |
| 1123 | dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); |
| 1124 | } |
| 1125 | if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { |
| 1126 | float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; |
| 1127 | dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); |
| 1128 | } |
| 1129 | } else if (np > 1) { |
| 1130 | // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch. |
| 1131 | // Therefore, all other warps also need to execute a __syncthreads(). |
| 1132 | // Otherwise the points at which warps synchronize with each other would become misaligned. |
| 1133 | __syncthreads(); |
| 1134 | } |
| 1135 | |
| 1136 | #pragma unroll |
| 1137 | for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { |
| 1138 | if (ntiles == 1) { |
| 1139 | const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data |
| 1140 | #pragma unroll |
| 1141 | for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { |
| 1142 | const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. |
| 1143 | |
| 1144 | #pragma unroll |
| 1145 | for (int l = 0; l < tile_B::ne; ++l) { |
| 1146 | const int k = k0 + tile_B::get_j(l); |
| 1147 | |
| 1148 | tile_Q[jc_cwd*tile_stride + k] = B.x[l]; |
| 1149 | } |
| 1150 | } |
| 1151 | } else { |
| 1152 | #pragma unroll |
| 1153 | for (int t = 0; t < ntiles/2; ++t) { |
| 1154 | const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; |
| 1155 | #pragma unroll |
| 1156 | for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { |
| 1157 | #pragma unroll |
| 1158 | for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { |
| 1159 | const int j = j0 + tile_C_VKQ_16::get_i(l); |
| 1160 | const int k = k0 + tile_C_VKQ_16::get_j(l); |
| 1161 | |
| 1162 | tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; |
| 1163 | } |
| 1164 | } |
| 1165 | } |
| 1166 | } |
| 1167 | |
| 1168 | __syncthreads(); |
| 1169 | |
| 1170 | if (np == 1 || threadIdx.y % np == 0) { |
| 1171 | // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. |
| 1172 | // The values after that are for the partial results of the individual blocks. |
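|  | // In other words: dstk_fixup holds gridDim.x*ncols float2 of meta data written by blocks with needs_fixup, then |
|  | // another gridDim.x*ncols float2 written by blocks with is_fixup, followed by ncols*(DV/2) float2 of partial VKQ |
|  | // results per block; dstk_fixup_data below points at this last region for the current block. |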
| 1173 | float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); |
| 1174 | |
| 1175 | #pragma unroll |
| 1176 | for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { |
| 1177 | const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); |
| 1178 | const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); |
| 1179 | const int stride_jc = WARP_SIZE / stride_k; |
| 1180 | |
| 1181 | if (k0_start == k0_stop) { |
| 1182 | continue; |
| 1183 | } |
| 1184 | |
| 1185 | #pragma unroll |
| 1186 | for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { |
| 1187 | const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); |
| 1188 | |
| 1189 | if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { |
| 1190 | break; |
| 1191 | } |
| 1192 | |
| 1193 | const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; |
| 1194 | |
| 1195 | const int j_dst = jc_dst / ncols2; |
| 1196 | const int c_dst = jc_dst % ncols2; |
| 1197 | |
| 1198 | if (!is_fixup && jt*ncols1 + j_dst >= ne01) { |
| 1199 | continue; |
| 1200 | } |
| 1201 | |
| 1202 | const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; |
| 1203 | #pragma unroll |
| 1204 | for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
| 1205 | const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
| 1206 | |
| 1207 | float2 dstk_val = make_float2(0.0f, 0.0f); |
| 1208 | #pragma unroll |
| 1209 | for (int ip = 0; ip < np; ++ip) { |
| 1210 | const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; |
| 1211 | const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); |
| 1212 | dstk_val.x += dstk_val_add.x*KQ_crs; |
| 1213 | dstk_val.y += dstk_val_add.y*KQ_crs; |
| 1214 | } |
| 1215 | |
| 1216 | if (!needs_fixup && !is_fixup) { |
| 1217 | const float KQ_rowsum_j = meta_j[1]; |
| 1218 | dstk_val.x /= KQ_rowsum_j; |
| 1219 | dstk_val.y /= KQ_rowsum_j; |
| 1220 | } |
| 1221 | |
| 1222 | if (is_fixup) { |
| 1223 | dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; |
| 1224 | } else { |
| 1225 | dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; |
| 1226 | } |
| 1227 | } |
| 1228 | } |
| 1229 | } |
| 1230 | } |
| 1231 | if (np > 1) { |
| 1232 | __syncthreads(); |
| 1233 | } |
| 1234 | } |
| 1235 | #else |
| 1236 | GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, |
| 1237 | scale, slope, logit_softcap, ne01, ne02, |
| 1238 | stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, |
| 1239 | jt, kb0_start, kb0_stop); |
| 1240 | NO_DEVICE_CODE; |
| 1241 | #endif // TURING_MMA_AVAILABLE |
| 1242 | } |
| 1243 | |
| 1244 | template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla> |
| 1245 | __launch_bounds__(nwarps*WARP_SIZE, 1) |
| 1246 | static __global__ void flash_attn_ext_f16( |
| 1247 | const char * __restrict__ Q, |
| 1248 | const char * __restrict__ K, |
| 1249 | const char * __restrict__ V, |
| 1250 | const char * __restrict__ mask, |
| 1251 | const char * __restrict__ sinks, |
| 1252 | const int * __restrict__ KV_max, |
| 1253 | float * __restrict__ dst, |
| 1254 | float2 * __restrict__ dst_meta, |
| 1255 | const float scale, |
| 1256 | const float max_bias, |
| 1257 | const float m0, |
| 1258 | const float m1, |
| 1259 | const uint32_t n_head_log2, |
| 1260 | const float logit_softcap, |
| 1261 | const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, |
| 1262 | const int32_t nb01, const int32_t nb02, const int32_t nb03, |
| 1263 | const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, |
| 1264 | const int32_t nb11, const int32_t nb12, const int64_t nb13, |
| 1265 | const int32_t nb21, const int32_t nb22, const int64_t nb23, |
| 1266 | const int32_t ne31, const int32_t ne32, const int32_t ne33, |
| 1267 | const int32_t nb31, const int32_t nb32, const int64_t nb33) { |
| 1268 | #if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) |
| 1269 | |
| 1270 | // Skip unused kernel variants for faster compilation: |
| 1271 | if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { |
| 1272 | NO_DEVICE_CODE; |
| 1273 | return; |
| 1274 | } |
| 1275 | #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 1276 | if (ncols1*ncols2 > 32) { |
| 1277 | NO_DEVICE_CODE; |
| 1278 | return; |
| 1279 | } |
| 1280 | #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING |
| 1281 | |
| 1282 | static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV" ); |
| 1283 | |
| 1284 | typedef fattn_mma_f16_config<DKQ, DV> c; |
| 1285 | |
| 1286 | static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa" ); |
| 1287 | |
| 1288 | const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. |
| 1289 | |
| 1290 | const int stride_Q1 = nb01 / sizeof(float2); |
| 1291 | const int stride_Q2 = nb02 / sizeof(float2); |
| 1292 | const int stride_K = nb11 / sizeof(half2); |
| 1293 | const int stride_mask = nb31 / sizeof(half2); |
| 1294 | |
| 1295 | const int stride_V = mla ? stride_K : nb21 / sizeof(half2); |
| 1296 | |
| 1297 | const int iter_k = ne11 / FATTN_KQ_STRIDE; |
| 1298 | const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; |
| 1299 | |
| 1300 | constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. |
| 1301 | |
| 1302 | // kbc == k block continuous, current index in continuous ijk space. |
| 1303 | int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; |
| 1304 | const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; |
| 1305 | |
| 1306 | // If the seams of 2 CUDA blocks fall within an output tile, their results need to be combined. |
| 1307 | // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). |
| 1308 | // In the most general case >2 seams can fall into the same tile. |
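|  | // A compact way to read the index math below: kbc walks this block's slice of the flattened |
|  | // (sequence, head, j, k) iteration space; tiles fully covered by one block are written straight to dst, while |
|  | // tiles shared between blocks additionally go through dst_meta / the fixup buffer (needs_fixup / is_fixup) and |
|  | // are expected to be merged by a separate combine step afterwards. |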
| 1309 | |
| 1310 | // kb0 == k start index when in the output tile. |
| 1311 | int kb0_start = kbc % iter_k; |
| 1312 | int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); |
| 1313 | |
| 1314 | while (kbc < kbc_stop && kb0_stop == iter_k) { |
| 1315 | const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); |
| 1316 | const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 |
| 1317 | const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. |
| 1318 | |
| 1319 | const int head0 = zt * ncols2; |
| 1320 | |
| 1321 | const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); |
| 1322 | const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); |
| 1323 | const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : |
| 1324 | (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); |
| 1325 | float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); |
| 1326 | |
| 1327 | const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); |
| 1328 | const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; |
| 1329 | |
| 1330 | const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; |
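|      | // ALiBi slopes differ per head; with ncols2 > 1 a single tile spans multiple heads, so a uniform slope of 1.0f
|      | // is used there (the host side presumably only selects ncols2 > 1 when max_bias == 0, i.e. ALiBi is disabled).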
| 1331 | |
| 1332 | const int kb0_start_kernel = kb0_start * kb_niter; |
| 1333 | int kb0_stop_kernel = kb0_stop * kb_niter; |
| 1334 | |
| 1335 | if (KV_max) { |
| 1336 | kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); |
| 1337 | } |
| 1338 | |
| 1339 | constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. |
| 1340 | if (kb0_start == 0) { |
| 1341 | constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. |
| 1342 | flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup> |
| 1343 | (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, |
| 1344 | ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); |
| 1345 | } else { |
| 1346 | constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. |
| 1347 | flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup> |
| 1348 | (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, |
| 1349 | ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); |
| 1350 | } |
| 1351 | |
| 1352 | kbc += iter_k; |
| 1353 | kbc -= kbc % iter_k; |
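|      | // Together these two lines round kbc up to the next multiple of iter_k, i.e. to the start of the next output tile.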
| 1354 | |
| 1355 | kb0_start = 0; |
| 1356 | kb0_stop = min(iter_k, kbc_stop - kbc); |
| 1357 | } |
| 1358 | |
| 1359 | if (kbc >= kbc_stop) { |
| 1360 | return; |
| 1361 | } |
| 1362 | |
| 1363 | const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); |
| 1364 | const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 |
| 1365 | const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. |
| 1366 | |
| 1367 | const int head0 = zt * ncols2; |
| 1368 | |
| 1369 | const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); |
| 1370 | const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); |
| 1371 | const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : |
| 1372 | (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); |
| 1373 | float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); |
| 1374 | |
| 1375 | const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); |
| 1376 | const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; |
| 1377 | |
| 1378 | const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; |
| 1379 | |
| 1380 | const int kb0_start_kernel = kb0_start * kb_niter; |
| 1381 | int kb0_stop_kernel = kb0_stop * kb_niter; |
| 1382 | |
| 1383 | if (KV_max) { |
| 1384 | kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); |
| 1385 | } |
| 1386 | |
| 1387 | constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. |
| 1388 | constexpr bool needs_fixup = false; |
| 1389 | flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup> |
| 1390 | (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, |
| 1391 | ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); |
| 1392 | #else |
| 1393 | GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, |
| 1394 | max_bias, m0, m1, n_head_log2, logit_softcap, |
| 1395 | ne00, ne01, ne02, ne03, |
| 1396 | nb01, nb02, nb03, |
| 1397 | ne10, ne11, ne12, ne13, |
| 1398 | nb11, nb12, nb13, |
| 1399 | nb21, nb22, nb23, |
| 1400 | ne31, ne32, ne33, |
| 1401 | nb31, nb32, nb33); |
| 1402 | NO_DEVICE_CODE; |
| 1403 | #endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) |
| 1404 | } |
| 1405 | |
| 1406 | template <int DKQ, int DV, int ncols1, int ncols2> |
| 1407 | void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 1408 | const ggml_tensor * KQV = dst; |
| 1409 | const int id = ggml_cuda_get_device(); |
| 1410 | const int cc = ggml_cuda_info().devices[id].cc; |
| 1411 | |
| 1412 | typedef fattn_mma_f16_config<DKQ, DV> c; |
| 1413 | |
| 1414 | const int nstages = cp_async_available(cc) ? c::nstages_target : 0; |
| 1415 | |
| 1416 | constexpr int ncols = ncols1 * ncols2; |
| 1417 | constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. |
| 1418 | constexpr int cols_per_warp = ntiles * tile_B::I; |
| 1419 | constexpr int nwarps_max_x = ncols / cols_per_warp; |
| 1420 | constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; |
| 1421 | constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; |
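|      | // Illustrative example (assuming tile_B::I == 8, tile_A::I == 16 and a config with nbatch_fa == 64, nwarps_max == 4):
|      | // for ncols == 16 this gives ntiles == 2, cols_per_warp == 16, nwarps_max_x == 1, nwarps_max_y == 4 and thus nwarps == 4.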
| 1422 | |
| 1423 | constexpr bool mla = DKQ == 576; |
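|      | // DKQ == 576 with DV == 512 is the DeepSeek MLA shape; see the kernel above, where V is addressed as an
|      | // offset view into K (V_h2 = K_h2 + (DKQ/2 - DV/2)) instead of loading a separate V tensor.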
| 1424 | |
| 1425 | const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); |
| 1426 | const int nbatch_V2      = c::get_nbatch_V2_host     (cc, ncols);
| 1427 | const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); |
| 1428 | |
| 1429 | static_assert(DKQ % tile_B::J == 0, "bad DKQ" ); |
| 1430 | static_assert(DV % tile_A::J == 0, "bad DV" ); |
| 1431 | static_assert(ncols % cols_per_warp == 0, "bad ncols" ); |
| 1432 | |
| 1433 | const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
| 1434 | const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); |
| 1435 | const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); |
| 1436 | const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); |
| 1437 | const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); |
| 1438 | |
| 1439 | const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; |
| 1440 | |
| 1441 | const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? |
| 1442 | std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
| 1443 | nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); |
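|      | // Illustrative sizing (assumed values: DKQ == 64, nbatch_fa == 64, nbatch_K2 == nbatch_V2 == nbatch_combine == 32,
|      | // ncols == 16, ncols1 == 8, nwarps == 4, cols_per_warp == 16, Q_in_reg, 2 stages; sizeof(half2) == 4):
|      | //   KV = 64*(32+4+32+4)*4 = 18432 B, Q = 16*(32+4)*4 = 2304 B, mask = 8*(32+4)*4 = 1152 B, combine = 4*16*(32+4)*4 = 9216 B,
|      | //   total = max(9216, max(2304, 18432 + 1152)) = 19584 B.
|      | // The "+ 4" half2 of per-row padding presumably staggers rows across shared memory banks to reduce bank conflicts.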
| 1444 | |
| 1445 | float logit_softcap; |
| 1446 | memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
| 1447 | |
| 1448 | fattn_kernel_t fattn_kernel; |
| 1449 | if (logit_softcap == 0.0f) { |
| 1450 | constexpr bool use_logit_softcap = false; |
| 1451 | fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>; |
| 1452 | |
| 1453 | #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) |
| 1454 | static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; |
| 1455 | if (!shared_memory_limit_raised[id]) { |
| 1456 | CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); |
| 1457 | shared_memory_limit_raised[id] = true; |
| 1458 | } |
| 1459 | #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) |
| 1460 | } else { |
| 1461 | constexpr bool use_logit_softcap = true; |
| 1462 | fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>; |
| 1463 | |
| 1464 | #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) |
| 1465 | static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; |
| 1466 | if (!shared_memory_limit_raised[id]) { |
| 1467 | CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); |
| 1468 | shared_memory_limit_raised[id] = true; |
| 1469 | } |
| 1470 | #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) |
| 1471 | } |
| 1472 | |
| 1473 | launch_fattn<DV, ncols1, ncols2> |
| 1474 | (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); |
| 1475 | } |
| 1476 | |
| 1477 | |
| 1478 | #define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \ |
| 1479 | template void ggml_cuda_flash_attn_ext_mma_f16_case \ |
| 1480 | <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ |
| 1481 | |
| 1482 | #define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ |
| 1483 | extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \ |
| 1484 | extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \ |
| 1485 | extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \ |
| 1486 | extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \ |
| 1487 | extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \ |
| 1488 | |
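|      | // Example: DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64, 16) declares the (ncols1, ncols2) combinations
|      | // (16, 1), (8, 2), (4, 4), (2, 8) and (1, 16) for DKQ == DV == 64.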
| 1489 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8) |
| 1490 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8) |
| 1491 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8) |
| 1492 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8) |
| 1493 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8) |
| 1494 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8) |
| 1495 | |
| 1496 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16) |
| 1497 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16) |
| 1498 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16) |
| 1499 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16) |
| 1500 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16) |
| 1501 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16) |
| 1502 | |
| 1503 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32) |
| 1504 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32) |
| 1505 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32) |
| 1506 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32) |
| 1507 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32) |
| 1508 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32) |
| 1509 | |
| 1510 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64) |
| 1511 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64) |
| 1512 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64) |
| 1513 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) |
| 1514 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) |
| 1515 | DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) |
| 1516 | |
| 1517 | // The number of viable configurations for Deepseek is very limited: |
| 1518 | extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); |
| 1519 | extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); |
| 1520 | extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); |
| 1521 | |