#pragma once

#include "common.cuh"

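// The WMMA flash attention kernel is compiled for NVIDIA GPUs with tensor cores
// (Volta or newer) and for MUSA devices.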
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
#define GGML_USE_WMMA_FATTN
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)

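// When building with rocwmma (GGML_HIP_ROCWMMA_FATTN), enable the WMMA kernel per AMD
// architecture: CDNA (except the broken rocwmma v2.0.0), RDNA3, and RDNA4 with rocwmma >= v2.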
#if defined(GGML_HIP_ROCWMMA_FATTN)
#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
#define GGML_USE_WMMA_FATTN
#elif defined(CDNA)
#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
#if defined(RDNA3)
#define GGML_USE_WMMA_FATTN
#endif // defined(RDNA3)
#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
#define GGML_USE_WMMA_FATTN
#elif defined(RDNA4)
#warning "rocwmma fattn is not supported on RDNA4 with rocwmma < v2.0.0, expect degraded performance"
#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
#endif // defined(GGML_HIP_ROCWMMA_FATTN)

// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
// Returns true for NVIDIA GPUs when the highest compiled arch is exactly Volta, for
// MooreThreads and RDNA3 GPUs, and for CDNA/RDNA4 GPUs when built against a suitable rocwmma.
static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
    return false;
#else
    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
        GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
        return true;
    } else if (GGML_CUDA_CC_IS_CDNA(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
        return true;
#else
        return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
        return true;
#else
        return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
    } else {
        return false;
    }
#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
}

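// Launches the WMMA-based FP16 flash attention kernel for the given flash attention dst tensor.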
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);