| 1 | #pragma once |
| 2 | |
| 3 | #include "common.cuh" |
| 4 | |
// Define GGML_USE_WMMA_FATTN for non-HIP builds whose __CUDA_ARCH__ is at least GGML_CUDA_CC_VOLTA,
// and for every MUSA build.
// Note: an undefined __CUDA_ARCH__ is replaced by 0 inside #if, so in that case the comparison
// only succeeds if GGML_CUDA_CC_VOLTA is <= 0 (presumably it is positive — confirm in common.cuh).
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
#define GGML_USE_WMMA_FATTN
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
| 8 | |
// rocWMMA gating: when GGML_HIP_ROCWMMA_FATTN is set, define GGML_USE_WMMA_FATTN per GPU family,
// subject to the rocWMMA version checks below.
#if defined(GGML_HIP_ROCWMMA_FATTN)
// CDNA: enabled unless the rocWMMA version has major >= 2 with minor == 0 and patch == 0
// (the warning names v2.0.0); in that case only a warning is emitted.
#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
#define GGML_USE_WMMA_FATTN
#elif defined(CDNA)
#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
// RDNA3: enabled regardless of the rocWMMA version.
#if defined(RDNA3)
#define GGML_USE_WMMA_FATTN
#endif // defined(RDNA3)
// RDNA4: enabled only with rocWMMA major version 2 or newer; otherwise only a warning is emitted.
#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
#define GGML_USE_WMMA_FATTN
#elif defined(RDNA4)
#warning "rocwmma fattn is not supported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
| 24 | |
| 25 | // WMMA flash attention requires FP16 matrix instructions to be available for ggml code. |
| 26 | static bool ggml_cuda_should_use_wmma_fattn(const int cc) { |
| 27 | #if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) |
| 28 | return false; |
| 29 | #else |
| 30 | if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(arch: cc) == GGML_CUDA_CC_VOLTA) || |
| 31 | GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) { |
| 32 | return true; |
| 33 | } else if (GGML_CUDA_CC_IS_CDNA(cc)){ |
| 34 | #if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) |
| 35 | return true; |
| 36 | #else |
| 37 | return false; |
| 38 | #endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) |
| 39 | } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { |
| 40 | #if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 |
| 41 | return true; |
| 42 | #else |
| 43 | return false; |
| 44 | #endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 |
| 45 | } else { |
| 46 | return false; |
| 47 | } |
| 48 | #endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) |
| 49 | } |
| 50 | |
| 51 | void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
| 52 | |