| 1 | #include "common.cuh" |
| 2 | #include "fattn-common.cuh" |
| 3 | |
| 4 | static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) { |
| 5 | return 128; |
| 6 | GGML_UNUSED(cc); |
| 7 | } |
| 8 | |
| 9 | static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { |
| 10 | return 128; |
| 11 | } |
| 12 | |
| 13 | // Currently llvm with the amdgcn target does not support unrolling loops |
| 14 | // that contain a break that cannot be resolved at compile time. |
| 15 | #ifdef __clang__ |
| 16 | #pragma clang diagnostic push |
| 17 | #pragma clang diagnostic ignored "-Wpass-failed" |
| 18 | #endif // __clang__ |
| 19 | template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size |
| 20 | __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) |
| 21 | static __global__ void flash_attn_ext_vec( |
| 22 | const char * __restrict__ Q, |
| 23 | const char * __restrict__ K, |
| 24 | const char * __restrict__ V, |
| 25 | const char * __restrict__ mask, |
| 26 | const char * __restrict__ sinks, |
| 27 | const int * __restrict__ KV_max, |
| 28 | float * __restrict__ dst, |
| 29 | float2 * __restrict__ dst_meta, |
| 30 | const float scale, |
| 31 | const float max_bias, |
| 32 | const float m0, |
| 33 | const float m1, |
| 34 | const uint32_t n_head_log2, |
| 35 | const float logit_softcap, |
| 36 | const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, |
| 37 | const int32_t nb01, const int32_t nb02, const int32_t nb03, |
| 38 | const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, |
| 39 | const int32_t nb11, const int32_t nb12, const int64_t nb13, |
| 40 | const int32_t nb21, const int32_t nb22, const int64_t nb23, |
| 41 | const int32_t ne31, const int32_t ne32, const int32_t ne33, |
| 42 | const int32_t nb31, const int32_t nb32, const int64_t nb33) { |
| 43 | #ifdef FLASH_ATTN_AVAILABLE |
| 44 | |
| 45 | // Skip unused kernel variants for faster compilation: |
| 46 | if (use_logit_softcap && !(D == 128 || D == 256)) { |
| 47 | GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, |
| 48 | max_bias, m0, m1, n_head_log2, logit_softcap, |
| 49 | ne00, ne01, ne02, ne03, |
| 50 | nb01, nb02, nb03, |
| 51 | ne10, ne11, ne12, ne13, |
| 52 | nb11, nb12, nb13, |
| 53 | nb21, nb22, nb23, |
| 54 | ne31, ne32, ne33, |
| 55 | nb31, nb32, nb33); |
| 56 | NO_DEVICE_CODE; |
| 57 | return; |
| 58 | } |
| 59 | |
| 60 | // In this kernel Q, K, V are matrices while i, j, k are matrix indices. |
| 61 | |
| 62 | constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); |
| 63 | constexpr int cpy_ne = cpy_nb / 4; |
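// cpy_nb is the widest copy in bytes used with ggml_cuda_memcpy_1 below; cpy_ne is the
// corresponding number of 32-bit elements per copy.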
| 64 | |
| 65 | #ifdef GGML_USE_HIP |
| 66 | #ifdef RDNA |
| 67 | constexpr int nthreads_KQ_q = 2; |
| 68 | #else |
| 69 | constexpr int nthreads_KQ_q = 4; |
| 70 | #endif // RDNA |
| 71 | constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); |
| 72 | #else |
| 73 | constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32); |
| 74 | constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); |
| 75 | #endif // GGML_USE_HIP |
| 76 | |
| 77 | constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); |
| 78 | constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; |
| 79 | constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; |
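// nthreads_KQ threads cooperate on each KQ dot product and nthreads_V threads on each VKQ
// accumulation; both counts must divide the warp size evenly (see the asserts below).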
| 80 | |
| 81 | static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_KQ"); |
| 82 | static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); |
| 83 | |
| 84 | constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; |
| 85 | constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; |
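// Each dequantize_V call produces V_rows_per_thread consecutive V values per thread and each
// warp covers V_cols_per_iter KV positions per iteration of the V loop.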
| 86 | |
| 87 | constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>(); |
| 88 | constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; |
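// For quantized K the Q values are converted to q8_1 so the KQ dot products can use integer math.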
| 89 | #ifdef FAST_FP16_AVAILABLE |
| 90 | constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>(); |
| 91 | #else |
| 92 | constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>(); |
| 93 | #endif // FAST_FP16_AVAILABLE |
| 94 | |
| 95 | const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. |
| 96 | |
| 97 | const int sequence = blockIdx.z / ne02; |
| 98 | const int head = blockIdx.z - sequence*ne02; |
| 99 | const int gqa_ratio = ne02 / ne12; // With grouped query attention there is more than one Q matrix per K, V matrix. |
| 100 | Q += nb03*sequence + nb02* head + nb01*ic0; |
| 101 | K += nb13*sequence + nb12*(head / gqa_ratio); |
| 102 | V += nb23*sequence + nb22*(head / gqa_ratio); |
| 103 | |
| 104 | const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); |
| 105 | |
| 106 | const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); |
| 107 | |
| 108 | static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); |
| 109 | constexpr int nwarps = nthreads / WARP_SIZE; |
| 110 | const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; |
| 111 | __builtin_assume(tid < nthreads); |
| 112 | |
| 113 | constexpr int ne_KQ = ncols*D; |
| 114 | constexpr int ne_combine = nwarps*V_cols_per_iter*D; |
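// The shared KQ buffer below is reused: it holds the KQ scores during the main loop and the
// per-warp VKQ partial results (ne_combine values) during the final combination step.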
| 115 | #ifdef FAST_FP16_AVAILABLE |
| 116 | half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; |
| 117 | __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; |
| 118 | #else |
| 119 | float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; |
| 120 | __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; |
| 121 | #endif // FAST_FP16_AVAILABLE |
| 122 | |
| 123 | float KQ_max[ncols]; |
| 124 | float KQ_sum[ncols]; |
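// Running softmax maximum and denominator per Q column, updated online while iterating over KV.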
| 125 | #pragma unroll |
| 126 | for (int j = 0; j < ncols; ++j) { |
| 127 | KQ_max[j] = -FLT_MAX/2.0f; |
| 128 | KQ_sum[j] = 0.0f; |
| 129 | } |
| 130 | |
| 131 | // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: |
| 132 | #ifdef FAST_FP16_AVAILABLE |
| 133 | half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. |
| 134 | #else |
| 135 | float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. |
| 136 | #endif // FAST_FP16_AVAILABLE |
| 137 | int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; |
| 138 | float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; |
| 139 | if constexpr (Q_q8_1) { |
| 140 | #pragma unroll |
| 141 | for (int j0 = 0; j0 < ncols; j0 += nwarps) { |
| 142 | const int j = j0 + threadIdx.y; |
| 143 | |
| 144 | if (j0 + nwarps > ncols && j >= ncols) { |
| 145 | break; |
| 146 | } |
| 147 | |
| 148 | // Reuse KQ as temporary storage for converting Q to q8_1: |
| 149 | int * tmp_q_i32 = (int *) &KQ[j*D]; |
| 150 | float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); |
| 151 | |
| 152 | // Set memory to zero if out of bounds: |
| 153 | if (ncols > 1 && ic0 + j >= ne01) { |
| 154 | #pragma unroll |
| 155 | for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { |
| 156 | const int i = i0 + threadIdx.x; |
| 157 | |
| 158 | if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { |
| 159 | tmp_q_i32[i] = 0; |
| 160 | } |
| 161 | } |
| 162 | if (threadIdx.x < D/QK8_1) { |
| 163 | tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); |
| 164 | } |
| 165 | } else { |
| 166 | const float * Q_f = (const float *) (Q + j*nb01); |
| 167 | constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE; |
| 168 | #pragma unroll |
| 169 | for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) { |
| 170 | quantize_q8_1_to_shared<float2, nthreads_quantize> |
| 171 | (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1); |
| 172 | } |
| 173 | } |
| 174 | } |
| 175 | |
| 176 | __syncthreads(); |
| 177 | |
| 178 | #pragma unroll |
| 179 | for (int j = 0; j < ncols; ++j) { |
| 180 | int * tmp_q_i32 = (int *) &KQ[j*D]; |
| 181 | float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); |
| 182 | |
| 183 | #pragma unroll |
| 184 | for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) { |
| 185 | const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ); |
| 186 | |
| 187 | Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i]; |
| 188 | Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1]; |
| 189 | } |
| 190 | } |
| 191 | |
| 192 | __syncthreads(); |
| 193 | } else { |
| 194 | #ifdef FAST_FP16_AVAILABLE |
| 195 | const half2 scale_h2 = make_half2(scale, scale); |
| 196 | #pragma unroll |
| 197 | for (int j = 0; j < ncols; ++j) { |
| 198 | const float2 * Q_j = (const float2 *) (Q + j*nb01); |
| 199 | #pragma unroll |
| 200 | for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { |
| 201 | const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; |
| 202 | |
| 203 | float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; |
| 204 | if (ncols == 1 || ic0 + j < ne01) { |
| 205 | ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]); |
| 206 | ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); |
| 207 | } |
| 208 | #pragma unroll |
| 209 | for (int i1 = 0; i1 < cpy_ne; ++i1) { |
| 210 | Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y); |
| 211 | } |
| 212 | } |
| 213 | #pragma unroll |
| 214 | for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { |
| 215 | Q_reg[j][k] *= scale_h2; |
| 216 | } |
| 217 | } |
| 218 | #else |
| 219 | #pragma unroll |
| 220 | for (int j = 0; j < ncols; ++j) { |
| 221 | const float2 * Q_j = (const float2 *) (Q + j*nb01); |
| 222 | #pragma unroll |
| 223 | for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { |
| 224 | const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; |
| 225 | if (ncols == 1 || ic0 + j < ne01) { |
| 226 | ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); |
| 227 | ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); |
| 228 | } |
| 229 | } |
| 230 | #pragma unroll |
| 231 | for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { |
| 232 | Q_reg[j][k].x *= scale; |
| 233 | Q_reg[j][k].y *= scale; |
| 234 | } |
| 235 | } |
| 236 | #endif // FAST_FP16_AVAILABLE |
| 237 | } |
| 238 | |
| 239 | const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; |
| 240 | K += blockIdx.y*nthreads * nb11; |
| 241 | V += blockIdx.y*nthreads * nb21; |
| 242 | maskh += blockIdx.y*nthreads; |
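// Loop over the KV sequence in tiles of nthreads positions: each y-block starts at its own
// offset and advances by gridDim.y*nthreads, so the tiles are interleaved across the y-blocks.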
| 243 | for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads, |
| 244 | // Increment pointers after each loop: |
| 245 | K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) { |
| 246 | |
| 247 | // Calculate KQ tile and keep track of new maximum KQ values: |
| 248 | float KQ_reg[ncols]; // KQ in registers. |
| 249 | |
| 250 | float KQ_max_new[ncols]; |
| 251 | #pragma unroll |
| 252 | for (int j = 0; j < ncols; ++j) { |
| 253 | KQ_max_new[j] = KQ_max[j]; |
| 254 | } |
| 255 | |
| 256 | #pragma unroll |
| 257 | for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) { |
| 258 | const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0; |
| 259 | |
| 260 | #pragma unroll |
| 261 | for (int j = 0; j < ncols; ++j) { |
| 262 | float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); |
| 263 | sum = warp_reduce_sum<nthreads_KQ>(sum); |
| 264 | |
| 265 | if (use_logit_softcap) { |
| 266 | sum = logit_softcap*tanhf(sum); |
| 267 | } |
| 268 | |
| 269 | if (mask) { |
| 270 | sum += slope*__half2float(maskh[j*ne11 + i_KQ]); |
| 271 | } |
| 272 | |
| 273 | KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); |
| 274 | |
| 275 | if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { |
| 276 | KQ_reg[j] = sum; |
| 277 | } |
| 278 | } |
| 279 | } |
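// Online softmax update: reduce the new maxima across the warp, rescale the running sums and
// the VKQ accumulators by exp(old_max - new_max), then store exp(KQ - max) to shared memory.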
| 280 | |
| 281 | #pragma unroll |
| 282 | for (int j = 0; j < ncols; ++j) { |
| 283 | #pragma unroll |
| 284 | for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) { |
| 285 | KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE)); |
| 286 | } |
| 287 | const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]); |
| 288 | KQ_max[j] = KQ_max_new[j]; |
| 289 | |
| 290 | KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]); |
| 291 | KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; |
| 292 | KQ[j*nthreads + tid] = KQ_reg[j]; |
| 293 | |
| 294 | #ifdef FAST_FP16_AVAILABLE |
| 295 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); |
| 296 | #pragma unroll |
| 297 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { |
| 298 | VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; |
| 299 | } |
| 300 | #else |
| 301 | #pragma unroll |
| 302 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { |
| 303 | VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; |
| 304 | VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; |
| 305 | } |
| 306 | #endif // FAST_FP16_AVAILABLE |
| 307 | } |
| 308 | |
| 309 | #ifndef GGML_USE_HIP |
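// Make the KQ values written to shared memory above visible to the other lanes of the warp.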
| 310 | __syncwarp(); |
| 311 | #endif // GGML_USE_HIP |
| 312 | |
| 313 | #pragma unroll |
| 314 | for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { |
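// Accumulate VKQ += V*KQ: each iteration the warp processes V_cols_per_iter KV positions.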
| 315 | const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V); |
| 316 | |
| 317 | #ifdef FAST_FP16_AVAILABLE |
| 318 | half2 KQ_k[ncols]; |
| 319 | #pragma unroll |
| 320 | for (int j = 0; j < ncols; ++j) { |
| 321 | KQ_k[j] = __half2half2(KQ[j*nthreads + k]); |
| 322 | } |
| 323 | #pragma unroll |
| 324 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { |
| 325 | half2 tmp[V_rows_per_thread/2]; |
| 326 | dequantize_V(V + k*nb21, tmp, |
| 327 | 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); |
| 328 | #pragma unroll |
| 329 | for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { |
| 330 | #pragma unroll |
| 331 | for (int j = 0; j < ncols; ++j) { |
| 332 | VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j]; |
| 333 | } |
| 334 | } |
| 335 | } |
| 336 | #else |
| 337 | float KQ_k[ncols]; |
| 338 | #pragma unroll |
| 339 | for (int j = 0; j < ncols; ++j) { |
| 340 | KQ_k[j] = KQ[j*nthreads + k]; |
| 341 | } |
| 342 | #pragma unroll |
| 343 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { |
| 344 | float2 tmp[V_rows_per_thread/2]; |
| 345 | dequantize_V(V + k*nb21, tmp, |
| 346 | 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); |
| 347 | #pragma unroll |
| 348 | for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { |
| 349 | #pragma unroll |
| 350 | for (int j = 0; j < ncols; ++j) { |
| 351 | VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j]; |
| 352 | VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j]; |
| 353 | } |
| 354 | } |
| 355 | } |
| 356 | #endif // FAST_FP16_AVAILABLE |
| 357 | } |
| 358 | } |
| 359 | |
| 360 | if (sinks && blockIdx.y == 0) { |
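// Attention sink: a per-head logit that enters the softmax normalization but has no associated
// value, so the running max/sum are updated and the VKQ accumulators rescaled without adding V.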
| 361 | const float sink = ((const float *) sinks)[head]; |
| 362 | |
| 363 | #pragma unroll |
| 364 | for (int j0 = 0; j0 < ncols; j0 += nwarps) { |
| 365 | const int j = j0 + threadIdx.y; |
| 366 | |
| 367 | if (j0 + nwarps > ncols && j >= ncols) { |
| 368 | break; |
| 369 | } |
| 370 | |
| 371 | const float kqmax_new_j = fmaxf(sink, KQ_max[j]); |
| 372 | const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j); |
| 373 | KQ_max[j] = kqmax_new_j; |
| 374 | |
| 375 | KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f); |
| 376 | |
| 377 | #ifdef FAST_FP16_AVAILABLE |
| 378 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); |
| 379 | #pragma unroll |
| 380 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { |
| 381 | VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; |
| 382 | } |
| 383 | #else |
| 384 | #pragma unroll |
| 385 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { |
| 386 | VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; |
| 387 | VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; |
| 388 | } |
| 389 | #endif // FAST_FP16_AVAILABLE |
| 390 | } |
| 391 | } |
| 392 | |
| 393 | __shared__ float KQ_max_shared[ncols][WARP_SIZE]; |
| 394 | __shared__ float KQ_sum_shared[ncols][WARP_SIZE]; |
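// Combine the per-warp results: reduce KQ_max and KQ_sum across warps via shared memory and
// rescale the VKQ accumulators to the block-wide maximum before summing them.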
| 395 | #pragma unroll |
| 396 | for (int j = 0; j < ncols; ++j) { |
| 397 | if (threadIdx.y == 0) { |
| 398 | KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f; |
| 399 | KQ_sum_shared[j][threadIdx.x] = 0.0f; |
| 400 | } |
| 401 | } |
| 402 | |
| 403 | __syncthreads(); |
| 404 | |
| 405 | #pragma unroll |
| 406 | for (int j = 0; j < ncols; ++j) { |
| 407 | if (threadIdx.x == 0) { |
| 408 | KQ_max_shared[j][threadIdx.y] = KQ_max[j]; |
| 409 | } |
| 410 | } |
| 411 | __syncthreads(); |
| 412 | |
| 413 | #pragma unroll |
| 414 | for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { |
| 415 | if (ncols > 1 && ic0 + j_VKQ >= ne01) { |
| 416 | break; |
| 417 | } |
| 418 | |
| 419 | float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x]; |
| 420 | kqmax_new = warp_reduce_max(kqmax_new); |
| 421 | const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); |
| 422 | KQ_max[j_VKQ] = kqmax_new; |
| 423 | |
| 424 | #ifdef FAST_FP16_AVAILABLE |
| 425 | half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) |
| 426 | + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); |
| 427 | |
| 428 | const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale); |
| 429 | #pragma unroll |
| 430 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { |
| 431 | VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2; |
| 432 | } |
| 433 | #pragma unroll |
| 434 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { |
| 435 | const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); |
| 436 | |
| 437 | ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); |
| 438 | } |
| 439 | #else |
| 440 | float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) |
| 441 | + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); |
| 442 | |
| 443 | #pragma unroll |
| 444 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { |
| 445 | VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale; |
| 446 | VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale; |
| 447 | } |
| 448 | #pragma unroll |
| 449 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { |
| 450 | const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); |
| 451 | |
| 452 | ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); |
| 453 | ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); |
| 454 | } |
| 455 | #endif // FAST_FP16_AVAILABLE |
| 456 | |
| 457 | KQ_sum[j_VKQ] *= kqmax_scale; |
| 458 | KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); |
| 459 | if (threadIdx.x == 0) { |
| 460 | KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ]; |
| 461 | } |
| 462 | |
| 463 | __syncthreads(); |
| 464 | |
| 465 | if (nthreads <= D || tid < D) { |
| 466 | KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x]; |
| 467 | KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); |
| 468 | |
| 469 | #pragma unroll |
| 470 | for (int i0 = 0; i0 < D; i0 += nthreads) { |
| 471 | float dst_val = 0; |
| 472 | #pragma unroll |
| 473 | for (int w = 0; w < nwarps; ++w) { |
| 474 | #pragma unroll |
| 475 | for (int v = 0; v < V_cols_per_iter; ++v) { |
| 476 | dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]); |
| 477 | } |
| 478 | } |
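// With a single KV split the result can be normalized immediately; otherwise normalization
// happens when the partial results of all splits are combined.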
| 479 | if (gridDim.y == 1) { |
| 480 | dst_val /= KQ_sum[j_VKQ]; |
| 481 | } |
| 482 | dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; |
| 483 | } |
| 484 | } |
| 485 | |
| 486 | if (j_VKQ < ncols-1) { |
| 487 | __syncthreads(); |
| 488 | } |
| 489 | |
| 490 | } |
| 491 | |
| 492 | if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { |
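// Store the per-split softmax maximum and denominator so the partial results can be merged later.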
| 493 | dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); |
| 494 | } |
| 495 | #else |
| 496 | GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, |
| 497 | max_bias, m0, m1, n_head_log2, logit_softcap, |
| 498 | ne00, ne01, ne02, ne03, |
| 499 | nb01, nb02, nb03, |
| 500 | ne10, ne11, ne12, ne13, |
| 501 | nb11, nb12, nb13, |
| 502 | nb21, nb22, nb23, |
| 503 | ne31, ne32, ne33, |
| 504 | nb31, nb32, nb33); |
| 505 | NO_DEVICE_CODE; |
| 506 | #endif // FLASH_ATTN_AVAILABLE |
| 507 | } |
| 508 | #ifdef __clang__ |
| 509 | #pragma clang diagnostic pop |
| 510 | #endif // __clang__ |
| 511 | |
| 512 | template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> |
| 513 | void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 514 | const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; |
| 515 | |
| 516 | const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); |
| 517 | const int nwarps = nthreads / WARP_SIZE; |
| 518 | fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>; |
| 519 | const bool need_f16_K = type_K == GGML_TYPE_F16; |
| 520 | const bool need_f16_V = type_V == GGML_TYPE_F16; |
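// All shared memory used by the kernel is allocated statically, so no dynamic shared memory is requested.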
| 521 | constexpr size_t nbytes_shared = 0; |
| 522 | launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); |
| 523 | } |
| 524 | |
| 525 | template <int D, ggml_type type_K, ggml_type type_V> |
| 526 | void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 527 | const ggml_tensor * KQV = dst; |
| 528 | const ggml_tensor * Q = dst->src[0]; |
| 529 | |
| 530 | float logit_softcap; |
| 531 | memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); |
| 532 | |
| 533 | if (Q->ne[1] == 1) { |
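// A single Q column (typical for token-by-token decoding) is handled with one column per block;
// otherwise two Q columns are processed per block below.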
| 534 | constexpr int cols_per_block = 1; |
| 535 | if (logit_softcap == 0.0f) { |
| 536 | constexpr bool use_logit_softcap = false; |
| 537 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); |
| 538 | } else { |
| 539 | constexpr bool use_logit_softcap = true; |
| 540 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); |
| 541 | } |
| 542 | return; |
| 543 | } |
| 544 | |
| 545 | constexpr int cols_per_block = 2; |
| 546 | if (logit_softcap == 0.0f) { |
| 547 | constexpr bool use_logit_softcap = false; |
| 548 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); |
| 549 | } else { |
| 550 | constexpr bool use_logit_softcap = true; |
| 551 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); |
| 552 | } |
| 553 | } |
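// The extern declarations below cover every supported combination of head size and K/V cache
// type; the corresponding template instantiations are provided by separate compilation units.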
| 554 | |
| 555 | #define DECL_FATTN_VEC_CASE(D, type_K, type_V) \ |
| 556 | template void ggml_cuda_flash_attn_ext_vec_case \ |
| 557 | <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ |
| 558 | |
| 559 | #define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ |
| 560 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ |
| 561 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ |
| 562 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ |
| 563 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ |
| 564 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ |
| 565 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ |
| 566 | |
| 567 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) |
| 568 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) |
| 569 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) |
| 570 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) |
| 571 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) |
| 572 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) |
| 573 | |
| 574 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) |
| 575 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) |
| 576 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) |
| 577 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) |
| 578 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) |
| 579 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) |
| 580 | |
| 581 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) |
| 582 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) |
| 583 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) |
| 584 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) |
| 585 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) |
| 586 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) |
| 587 | |