| 1 | #include "ggml-vulkan.h" |
| 2 | #include <vulkan/vulkan_core.h> |
| 3 | #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) |
| 4 | #include <chrono> |
| 5 | #include "ggml-cpu.h" |
| 6 | #endif |
| 7 | |
| 8 | // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- |
| 9 | #define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 |
| 10 | // We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE |
| 11 | // to avoid conflicts with applications or other libraries that might use it. |
| 12 | #if VK_HEADER_VERSION >= 301 |
| 13 | namespace vk::detail { class DispatchLoaderDynamic; } |
| 14 | using vk::detail::DispatchLoaderDynamic; |
| 15 | #else |
| 16 | namespace vk { class DispatchLoaderDynamic; } |
| 17 | using vk::DispatchLoaderDynamic; |
| 18 | #endif |
| 19 | DispatchLoaderDynamic & ggml_vk_default_dispatcher(); |
| 20 | #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() |
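|  | // Illustrative sketch only (the actual definition lives later in this file): the function is |
|  | // expected to hand back a library-private dispatcher, roughly along the lines of |
|  | //   static DispatchLoaderDynamic ggml_vk_dispatcher_instance;   // hypothetical name |
|  | //   DispatchLoaderDynamic & ggml_vk_default_dispatcher() { return ggml_vk_dispatcher_instance; } |
|  | // with the instance initialized via vkGetInstanceProcAddr before any vk:: call is dispatched. |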
| 21 | |
| 22 | #include <vulkan/vulkan.hpp> |
| 23 | |
| 24 | #include <algorithm> |
| 25 | #include <cmath> |
| 26 | #include <iomanip> |
| 27 | #include <iostream> |
| 28 | #include <tuple> |
| 29 | #include <vector> |
| 30 | #include <sstream> |
| 31 | #include <utility> |
| 32 | #include <memory> |
| 33 | #include <limits> |
| 34 | #include <map> |
| 35 | #include <unordered_map> |
| 37 | #include <mutex> |
| 38 | #include <future> |
| 39 | #include <thread> |
| 40 | |
| 41 | #if defined(_MSC_VER) |
| 42 | # define NOMINMAX 1 |
| 43 | # include <windows.h> |
| 44 | # define YIELD() YieldProcessor() |
| 45 | #elif defined(__clang__) || defined(__GNUC__) |
| 46 | # if defined(__x86_64__) || defined(__i386__) |
| 47 | # include <immintrin.h> |
| 48 | # define YIELD() _mm_pause() |
| 49 | # elif defined(__arm__) || defined(__aarch64__) |
| 50 | # if defined(__clang__) |
| 51 | # include <arm_acle.h> |
| 52 | # define YIELD() __yield() |
| 53 | # else |
| 54 | # define YIELD() asm volatile("yield") |
| 55 | # endif |
| 56 | # endif |
| 57 | #endif |
| 58 | |
| 59 | #if !defined(YIELD) |
| 60 | #define YIELD() |
| 61 | #endif |
| 62 | |
| 63 | #include "ggml-impl.h" |
| 64 | #include "ggml-backend-impl.h" |
| 65 | |
| 66 | #include "ggml-vulkan-shaders.hpp" |
| 67 | |
| 68 | // remove this once VK_KHR_shader_bfloat16 is more widely available in the Vulkan SDK |
| 69 | #if !defined(VK_KHR_shader_bfloat16) |
| 70 | |
| 71 | #define VK_KHR_shader_bfloat16 1 |
| 72 | #define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1 |
| 73 | #define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" |
| 74 | #define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000) |
| 75 | #define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) |
| 76 | |
| 77 | typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { |
| 78 | VkStructureType sType; |
| 79 | void* pNext; |
| 80 | VkBool32 shaderBFloat16Type; |
| 81 | VkBool32 shaderBFloat16DotProduct; |
| 82 | VkBool32 shaderBFloat16CooperativeMatrix; |
| 83 | } VkPhysicalDeviceShaderBfloat16FeaturesKHR; |
| 84 | #endif |
| 85 | |
| 86 | #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) |
| 87 | #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) |
| 88 | static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } |
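|  | // For reference: ROUNDUP_POW2(13, 8) == 16 and CEIL_DIV(13, 8) == 2 (N must be a power of two |
|  | // for ROUNDUP_POW2); is_pow2(8) is true, while is_pow2(1) is false because of the x > 1 check. |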
| 89 | |
| 90 | #define VK_VENDOR_ID_AMD 0x1002 |
| 91 | #define VK_VENDOR_ID_APPLE 0x106b |
| 92 | #define VK_VENDOR_ID_INTEL 0x8086 |
| 93 | #define VK_VENDOR_ID_NVIDIA 0x10de |
| 94 | |
| 95 | #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 |
| 96 | |
| 97 | #define GGML_VK_MAX_NODES 8192 |
| 98 | |
| 99 | #define VK_CHECK(err, msg) \ |
| 100 | do { \ |
| 101 | vk::Result err_ = (err); \ |
| 102 | if (err_ != vk::Result::eSuccess) { \ |
| 103 | fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ |
| 104 | #err, to_string(err_).c_str(), __FILE__, __LINE__); \ |
| 105 | exit(1); \ |
| 106 | } \ |
| 107 | } while (0) |
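|  | // Typical usage (illustrative; `device` and `fence` are placeholder names): |
|  | //   VK_CHECK(device.waitForFences({ fence }, true, UINT64_MAX), "wait_for_fences"); |
|  | // Any result other than vk::Result::eSuccess prints the stringified expression, the result |
|  | // name, and the call site, then exits. |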
| 108 | |
| 109 | #ifdef GGML_VULKAN_DEBUG |
| 110 | #define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl |
| 111 | #else |
| 112 | #define VK_LOG_DEBUG(msg) ((void) 0) |
| 113 | #endif // GGML_VULKAN_DEBUG |
| 114 | |
| 115 | struct ggml_backend_vk_context; |
| 116 | |
| 117 | #define MAX_PARAMETER_COUNT 12 |
| 118 | // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. |
| 119 | #define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) |
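|  | // With MAX_PARAMETER_COUNT == 12, MAX_FUSED_ADDS evaluates to 9. |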
| 120 | |
| 121 | struct vk_pipeline_struct { |
| 122 | std::string name; |
| 123 | vk::ShaderModule shader_module; |
| 124 | vk::PipelineLayout layout; |
| 125 | vk::Pipeline pipeline; |
| 126 | uint32_t push_constant_size; |
| 127 | uint32_t parameter_count; |
| 128 | std::array<uint32_t, 3> wg_denoms; |
| 129 | uint32_t align; |
| 130 | // true if fields have been set by ggml_vk_create_pipeline |
| 131 | bool initialized {}; |
| 132 | // set to true to request that the pipeline be compiled |
| 133 | std::atomic<bool> needed {}; |
| 134 | // set to true when the shader has been compiled |
| 135 | std::atomic<bool> compiled {}; |
| 136 | // number of registers used, extracted from pipeline executable properties |
| 137 | uint32_t register_count {}; |
| 138 | }; |
| 139 | |
| 140 | typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline; |
| 141 | typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref; |
| 142 | |
| 143 | static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); |
| 144 | |
| 145 | struct vk_matmul_pipeline_struct { |
| 146 | vk_pipeline l, m, s; |
| 147 | vk_pipeline a_l, a_m, a_s; |
| 148 | // Returns true when all unaligned pipelines are null. |
| 149 | // We only check the unaligned variants, since at least one unaligned pipeline must exist |
| 150 | // while the aligned pipelines are optional. |
| 151 | bool is_empty() const { |
| 152 | return l == nullptr && m == nullptr && s == nullptr; |
| 153 | } |
| 154 | }; |
| 155 | typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline; |
| 156 | |
| 157 | struct vk_matmul_pipeline2 { |
| 158 | vk_matmul_pipeline2() { |
| 159 | f16acc = std::make_shared<vk_matmul_pipeline_struct>(); |
| 160 | f32acc = std::make_shared<vk_matmul_pipeline_struct>(); |
| 161 | } |
| 162 | vk_matmul_pipeline f32acc; |
| 163 | vk_matmul_pipeline f16acc; |
| 164 | }; |
| 165 | |
| 166 | struct vk_device_struct; |
| 167 | typedef std::shared_ptr<vk_device_struct> vk_device; |
| 168 | typedef std::weak_ptr<vk_device_struct> vk_device_ref; |
| 169 | |
| 170 | struct vk_buffer_struct; |
| 171 | typedef std::shared_ptr<vk_buffer_struct> vk_buffer; |
| 172 | typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref; |
| 173 | |
| 174 | struct ggml_backend_vk_buffer_type_context { |
| 175 | std::string name; |
| 176 | vk_device device; |
| 177 | }; |
| 178 | |
| 179 | struct vk_queue; |
| 180 | |
| 181 | // Stores command pool/buffers. There's an instance of this |
| 182 | // for each (context,queue) pair and for each (device,queue) pair. |
| 183 | struct vk_command_pool { |
| 184 | void init(vk_device& device, vk_queue *q_); |
| 185 | void destroy(vk::Device& device); |
| 186 | |
| 187 | vk::CommandPool pool; |
| 188 | uint32_t cmd_buffer_idx; |
| 189 | std::vector<vk::CommandBuffer> cmd_buffers; |
| 190 | |
| 191 | vk_queue *q; |
| 192 | }; |
| 193 | |
| 194 | // Prevent simultaneous submissions to the same queue. |
| 195 | // This could be per vk_queue if we stopped having two vk_queue structures |
| 196 | // sharing the same vk::Queue. |
| 197 | static std::mutex queue_mutex; |
| 198 | |
| 199 | struct vk_queue { |
| 200 | uint32_t queue_family_index; |
| 201 | vk::Queue queue; |
| 202 | |
| 203 | vk_command_pool cmd_pool; |
| 204 | |
| 205 | vk::PipelineStageFlags stage_flags; |
| 206 | |
| 207 | bool transfer_only; |
| 208 | |
| 209 | // copy everything except the cmd_pool |
| 210 | void copyFrom(vk_queue &other) { |
| 211 | queue_family_index = other.queue_family_index; |
| 212 | queue = other.queue; |
| 213 | stage_flags = other.stage_flags; |
| 214 | transfer_only = other.transfer_only; |
| 215 | } |
| 216 | }; |
| 217 | |
| 218 | static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); |
| 219 | static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); |
| 220 | static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); |
| 221 | static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); |
| 222 | static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); |
| 223 | static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { |
| 224 | /* .get_name = */ ggml_backend_vk_buffer_type_name, |
| 225 | /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, |
| 226 | /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, |
| 227 | /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, |
| 228 | /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size, |
| 229 | /* .is_host = */ NULL, |
| 230 | }; |
| 231 | |
| 232 | #ifdef GGML_VULKAN_MEMORY_DEBUG |
| 233 | class vk_memory_logger; |
| 234 | #endif |
| 235 | class vk_perf_logger; |
| 236 | static void ggml_vk_destroy_buffer(vk_buffer& buf); |
| 237 | |
| 238 | static constexpr uint32_t mul_mat_vec_max_cols = 8; |
| 239 | static constexpr uint32_t p021_max_gqa_ratio = 8; |
| 240 | |
| 241 | enum vk_device_architecture { |
| 242 | OTHER, |
| 243 | AMD_GCN, |
| 244 | AMD_RDNA1, |
| 245 | AMD_RDNA2, |
| 246 | AMD_RDNA3, |
| 247 | INTEL_XE2, |
| 248 | NVIDIA_PRE_TURING, |
| 249 | }; |
| 250 | |
| 251 | static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { |
| 252 | vk::PhysicalDeviceProperties props = device.getProperties(); |
| 253 | |
| 254 | if (props.vendorID == VK_VENDOR_ID_AMD) { |
| 255 | const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties(); |
| 256 | |
| 257 | bool amd_shader_core_properties = false; |
| 258 | bool integer_dot_product = false; |
| 259 | bool subgroup_size_control = false; |
| 260 | |
| 261 | for (const auto& properties : ext_props) { |
| 262 | if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) { |
| 263 | amd_shader_core_properties = true; |
| 264 | } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) { |
| 265 | integer_dot_product = true; |
| 266 | } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { |
| 267 | subgroup_size_control = true; |
| 268 | } |
| 269 | } |
| 270 | |
| 271 | if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) { |
| 272 | return vk_device_architecture::OTHER; |
| 273 | } |
| 274 | |
| 275 | vk::PhysicalDeviceProperties2 props2; |
| 276 | vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd; |
| 277 | vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props; |
| 278 | vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; |
| 279 | |
| 280 | props2.pNext = &shader_core_props_amd; |
| 281 | shader_core_props_amd.pNext = &integer_dot_props; |
| 282 | integer_dot_props.pNext = &subgroup_size_control_props; |
| 283 | |
| 284 | device.getProperties2(&props2); |
| 285 | |
| 286 | if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) { |
| 287 | return vk_device_architecture::AMD_GCN; |
| 288 | } |
| 289 | if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) { |
| 290 | // RDNA |
| 291 | if (shader_core_props_amd.wavefrontsPerSimd == 20) { |
| 292 | return vk_device_architecture::AMD_RDNA1; |
| 293 | } |
| 294 | if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) { |
| 295 | return vk_device_architecture::AMD_RDNA3; |
| 296 | } |
| 297 | return vk_device_architecture::AMD_RDNA2; |
| 298 | } |
| 299 | } else if (props.vendorID == VK_VENDOR_ID_INTEL) { |
| 300 | const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties(); |
| 301 | |
| 302 | bool subgroup_size_control = false; |
| 303 | |
| 304 | for (const auto& properties : ext_props) { |
| 305 | if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { |
| 306 | subgroup_size_control = true; |
| 307 | } |
| 308 | } |
| 309 | |
| 310 | if (!subgroup_size_control) { |
| 311 | return vk_device_architecture::OTHER; |
| 312 | } |
| 313 | |
| 314 | vk::PhysicalDeviceProperties2 props2; |
| 315 | vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; |
| 316 | |
| 317 | props2.pNext = &subgroup_size_control_props; |
| 318 | device.getProperties2(&props2); |
| 319 | |
| 320 | if (subgroup_size_control_props.minSubgroupSize == 16) { |
| 321 | // The Xe2 architecture uses SIMD16, while previous Xe and Gen architectures use SIMD8. |
| 322 | // The minimum subgroup size matches the SIMD width, so we distinguish architectures by checking this value. |
| 323 | // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html |
| 324 | // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html |
| 325 | return vk_device_architecture::INTEL_XE2; |
| 326 | } |
| 327 | } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) { |
| 328 | const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties(); |
| 329 | |
| 330 | bool cooperative_matrix = false; |
| 331 | |
| 332 | // Detect "pre-turing" based on lack of coopmat support. |
| 333 | for (const auto& properties : ext_props) { |
| 334 | if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { |
| 335 | cooperative_matrix = true; |
| 336 | break; |
| 337 | } |
| 338 | } |
| 339 | |
| 340 | if (!cooperative_matrix) { |
| 341 | return vk_device_architecture::NVIDIA_PRE_TURING; |
| 342 | } |
| 343 | } |
| 344 | return vk_device_architecture::OTHER; |
| 345 | } |
| 346 | |
| 347 | enum vk_conv_shapes { |
| 348 | CONV_SHAPE_128x128, |
| 349 | CONV_SHAPE_64x32, |
| 350 | CONV_SHAPE_32x256, |
| 351 | CONV_SHAPE_COUNT, |
| 352 | }; |
| 353 | |
| 354 | enum dmmv_wg_sizes { |
| 355 | DMMV_WG_SIZE_SUBGROUP, |
| 356 | DMMV_WG_SIZE_LARGE, |
| 357 | DMMV_WG_SIZE_COUNT, |
| 358 | }; |
| 359 | |
| 360 | enum FaCodePath { |
| 361 | FA_SCALAR, |
| 362 | FA_COOPMAT1, |
| 363 | FA_COOPMAT2, |
| 364 | }; |
| 365 | |
| 366 | struct vk_fa_pipeline_state { |
| 367 | vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) |
| 368 | : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} |
| 369 | |
| 370 | uint32_t HSK, HSV; |
| 371 | bool small_rows; |
| 372 | FaCodePath path; |
| 373 | bool aligned; |
| 374 | bool f32acc; |
| 375 | |
| 376 | bool operator<(const vk_fa_pipeline_state &b) const { |
| 377 | return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < |
| 378 | std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); |
| 379 | } |
| 380 | }; |
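|  | // The operator< above provides the strict weak ordering required to use this struct as a |
|  | // std::map key (see pipeline_flash_attn_f32_f16 in vk_device_struct below). |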
| 381 | |
| 382 | enum shader_reduction_mode { |
| 383 | SHADER_REDUCTION_MODE_SHMEM, |
| 384 | SHADER_REDUCTION_MODE_HYBRID, |
| 385 | SHADER_REDUCTION_MODE_SUBGROUP, |
| 386 | SHADER_REDUCTION_MODE_COUNT, |
| 387 | }; |
| 388 | |
| 389 | static constexpr uint32_t num_argsort_pipelines = 11; |
| 390 | static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); |
| 391 | static constexpr uint32_t num_topk_moe_pipelines = 10; |
| 392 | |
| 393 | static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, |
| 394 | GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, |
| 395 | GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, |
| 396 | GGML_OP_RESHAPE }; |
| 397 | static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, |
| 398 | GGML_OP_VIEW, GGML_OP_GET_ROWS }; |
| 399 | static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, |
| 400 | GGML_OP_GET_ROWS, GGML_OP_RESHAPE, |
| 401 | GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; |
| 402 | |
| 403 | //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] |
| 404 | //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] |
| 405 | //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] |
| 406 | //node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ] |
| 407 | //node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ] |
| 408 | //node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ] |
| 409 | //node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] |
| 410 | //node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ] |
| 411 | //node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ] |
| 412 | //node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ] |
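|  | // Each edge triple below is { consumer node index within the fused pattern, src[] slot on the |
|  | // consumer, producer node index }. For example, { 4, 1, 3 } says node 4 (the get_rows) reads |
|  | // node 3 (the view) through src[1]. |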
| 413 | static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges { |
| 414 | { 1, 0, 0 }, // reshape->src[0] == softmax |
| 415 | { 2, 0, 0 }, // argsort->src[0] == softmax |
| 416 | { 3, 0, 2 }, // view->src[0] == argsort |
| 417 | { 4, 0, 1 }, // get_rows->src[0] == reshape |
| 418 | { 4, 1, 3 }, // get_rows->src[1] == view |
| 419 | { 5, 0, 4 }, // reshape->src[0] == get_rows |
| 420 | { 6, 0, 5 }, // sum_rows->src[0] == reshape |
| 421 | { 7, 0, 6 }, // clamp->src[0] == sum_rows |
| 422 | { 8, 0, 5 }, // div->src[0] == reshape |
| 423 | { 8, 1, 7 }, // div->src[1] == clamp |
| 424 | { 9, 0, 8 }, // reshape->src[0] == div |
| 425 | }; |
| 426 | |
| 427 | // same as early_softmax_norm but ending after the get_rows |
| 428 | static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges { |
| 429 | { 1, 0, 0 }, // reshape->src[0] == softmax |
| 430 | { 2, 0, 0 }, // argsort->src[0] == softmax |
| 431 | { 3, 0, 2 }, // view->src[0] == argsort |
| 432 | { 4, 0, 1 }, // get_rows->src[0] == reshape |
| 433 | { 4, 1, 3 }, // get_rows->src[1] == view |
| 434 | }; |
| 435 | |
| 436 | //node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ] |
| 437 | //node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ] |
| 438 | //node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ] |
| 439 | //node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ] |
| 440 | //node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ] |
| 441 | //node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ] |
| 442 | static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges { |
| 443 | { 1, 0, 0 }, // view->src[0] == argsort |
| 444 | { 2, 1, 1 }, // get_rows->src[1] == view |
| 445 | { 3, 0, 2 }, // reshape->src[0] == get_rows |
| 446 | { 4, 0, 3 }, // soft_max->src[0] == reshape |
| 447 | { 5, 0, 4 }, // reshape->src[0] == soft_max |
| 448 | }; |
| 449 | |
| 450 | enum topk_moe_mode { |
| 451 | TOPK_MOE_EARLY_SOFTMAX, |
| 452 | TOPK_MOE_EARLY_SOFTMAX_NORM, |
| 453 | TOPK_MOE_LATE_SOFTMAX, |
| 454 | TOPK_MOE_COUNT, |
| 455 | }; |
| 456 | |
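|  | // Maps the number of additional fused ops back to the pattern it came from. With the lists |
|  | // above: topk_moe_early_softmax_norm has 10 ops, so num == 9 selects TOPK_MOE_EARLY_SOFTMAX_NORM; |
|  | // topk_moe_early_softmax has 5 ops, so num == 4 selects TOPK_MOE_EARLY_SOFTMAX; anything else |
|  | // falls back to TOPK_MOE_LATE_SOFTMAX. |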
| 457 | static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { |
| 458 | topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : |
| 459 | num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX : |
| 460 | TOPK_MOE_LATE_SOFTMAX; |
| 461 | return mode; |
| 462 | } |
| 463 | |
| 464 | static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges { |
| 465 | { 1, 0, 0 }, // view->src[0] == rope |
| 466 | { 2, 0, 1 }, // set_rows->src[0] == view |
| 467 | }; |
| 468 | |
| 469 | static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges { |
| 470 | { 1, 0, 0 }, // mul->src[0] == rms |
| 471 | { 2, 0, 1 }, // rope->src[0] == mul |
| 472 | { 3, 0, 2 }, // view->src[0] == rope |
| 473 | { 4, 0, 3 }, // set_rows->src[0] == view |
| 474 | }; |
| 475 | |
| 476 | |
| 477 | struct vk_device_struct { |
| 478 | std::recursive_mutex mutex; |
| 479 | |
| 480 | vk::PhysicalDevice physical_device; |
| 481 | vk::PhysicalDeviceProperties properties; |
| 482 | std::string name; |
| 483 | uint64_t max_memory_allocation_size; |
| 484 | uint64_t max_buffer_size; |
| 485 | uint64_t suballocation_block_size; |
| 486 | bool fp16; |
| 487 | bool bf16; |
| 488 | bool pipeline_robustness; |
| 489 | vk::Device device; |
| 490 | uint32_t vendor_id; |
| 491 | vk::DriverId driver_id; |
| 492 | vk_device_architecture architecture; |
| 493 | vk_queue compute_queue; |
| 494 | vk_queue transfer_queue; |
| 495 | bool single_queue; |
| 496 | uint32_t subgroup_size; |
| 497 | uint32_t shader_core_count; |
| 498 | bool uma; |
| 499 | bool prefer_host_memory; |
| 500 | bool float_controls_rte_fp16; |
| 501 | bool subgroup_arithmetic; |
| 502 | bool subgroup_shuffle; |
| 503 | bool subgroup_ballot; |
| 504 | bool subgroup_clustered; |
| 505 | bool multi_add; |
| 506 | bool shader_int64; |
| 507 | bool buffer_device_address; |
| 508 | |
| 509 | bool add_rms_fusion; |
| 510 | uint32_t partials_binding_alignment; |
| 511 | |
| 512 | bool integer_dot_product; |
| 513 | // 0: default, 1: force mmvq, -1: disable mmvq |
| 514 | int32_t mmvq_mode; |
| 515 | |
| 516 | bool subgroup_size_control; |
| 517 | uint32_t subgroup_min_size; |
| 518 | uint32_t subgroup_max_size; |
| 519 | bool subgroup_require_full_support; |
| 520 | |
| 521 | bool coopmat_support; |
| 522 | bool coopmat_acc_f32_support {}; |
| 523 | bool coopmat_acc_f16_support {}; |
| 524 | bool coopmat_bf16_support {}; |
| 525 | bool coopmat_support_16x16x16_f16acc {}; |
| 526 | bool coopmat_support_16x16x16_f32acc {}; |
| 527 | bool coopmat1_fa_support {}; |
| 528 | uint32_t coopmat_m; |
| 529 | uint32_t coopmat_n; |
| 530 | uint32_t coopmat_k; |
| 531 | |
| 532 | bool coopmat_int_support; |
| 533 | uint32_t coopmat_int_m; |
| 534 | uint32_t coopmat_int_n; |
| 535 | uint32_t coopmat_int_k; |
| 536 | |
| 537 | bool coopmat2; |
| 538 | |
| 539 | bool pipeline_executable_properties_support {}; |
| 540 | |
| 541 | size_t idx; |
| 542 | |
| 543 | bool mul_mat_l[GGML_TYPE_COUNT]; |
| 544 | bool mul_mat_m[GGML_TYPE_COUNT]; |
| 545 | bool mul_mat_s[GGML_TYPE_COUNT]; |
| 546 | bool mul_mat_id_l[GGML_TYPE_COUNT]; |
| 547 | bool mul_mat_id_m[GGML_TYPE_COUNT]; |
| 548 | bool mul_mat_id_s[GGML_TYPE_COUNT]; |
| 549 | |
| 550 | vk::DescriptorSetLayout dsl; |
| 551 | |
| 552 | vk_matmul_pipeline pipeline_matmul_f32 {}; |
| 553 | vk_matmul_pipeline pipeline_matmul_f32_f16 {}; |
| 554 | vk_matmul_pipeline pipeline_matmul_bf16 {}; |
| 555 | vk_matmul_pipeline2 pipeline_matmul_f16; |
| 556 | vk_matmul_pipeline2 pipeline_matmul_f16_f32; |
| 557 | |
| 558 | vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; |
| 559 | vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; |
| 560 | vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; |
| 561 | |
| 562 | vk_matmul_pipeline pipeline_matmul_id_f32 {}; |
| 563 | vk_matmul_pipeline pipeline_matmul_id_bf16 {}; |
| 564 | vk_matmul_pipeline2 pipeline_matmul_id_f16; |
| 565 | vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; |
| 566 | |
| 567 | vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; |
| 568 | vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT]; |
| 569 | |
| 570 | vk_pipeline pipeline_matmul_split_k_reduce; |
| 571 | vk_pipeline pipeline_quantize_q8_1; |
| 572 | vk_pipeline pipeline_quantize_q8_1_x4; |
| 573 | |
| 574 | vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; |
| 575 | vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; |
| 576 | vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; |
| 577 | vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; |
| 578 | |
| 579 | vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; |
| 580 | |
| 581 | vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; |
| 582 | vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; |
| 583 | vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; |
| 584 | vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; |
| 585 | vk_pipeline pipeline_acc_f32; |
| 586 | |
| 587 | // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] |
| 588 | vk_pipeline pipeline_add[2][2][2]; |
| 589 | vk_pipeline pipeline_add_norepeat[2][2][2]; |
| 590 | vk_pipeline pipeline_sub[2][2][2]; |
| 591 | vk_pipeline pipeline_sub_norepeat[2][2][2]; |
| 592 | vk_pipeline pipeline_mul[2][2][2]; |
| 593 | vk_pipeline pipeline_mul_norepeat[2][2][2]; |
| 594 | vk_pipeline pipeline_div[2][2][2]; |
| 595 | vk_pipeline pipeline_div_norepeat[2][2][2]; |
| 596 | vk_pipeline pipeline_add_rms[2][2][2]; |
| 597 | vk_pipeline pipeline_add_rms_norepeat[2][2][2]; |
| 598 | |
| 599 | // indexed by num_additional_fused_ops == num_adds - 1 |
| 600 | vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS]; |
| 601 | vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS]; |
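|  | // e.g. a chain of 3 fused adds uses index 2 in the arrays above |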
| 602 | |
| 603 | vk_pipeline pipeline_add_id_f32; |
| 604 | |
| 605 | vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; |
| 606 | vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32; |
| 607 | vk_pipeline pipeline_scale_f32; |
| 608 | vk_pipeline pipeline_sqr_f32; |
| 609 | vk_pipeline pipeline_sqrt_f32; |
| 610 | vk_pipeline pipeline_sin_f32; |
| 611 | vk_pipeline pipeline_cos_f32; |
| 612 | vk_pipeline pipeline_clamp_f32; |
| 613 | vk_pipeline pipeline_pad_f32; |
| 614 | vk_pipeline pipeline_roll_f32; |
| 615 | vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; |
| 616 | vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; |
| 617 | vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; |
| 618 | vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; |
| 619 | vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; |
| 620 | vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; |
| 621 | vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; |
| 622 | vk_pipeline pipeline_norm_f32; |
| 623 | vk_pipeline pipeline_group_norm_f32; |
| 624 | vk_pipeline pipeline_rms_norm_f32; |
| 625 | vk_pipeline pipeline_rms_norm_mul_f32; |
| 626 | vk_pipeline pipeline_rms_norm_partials_f32; |
| 627 | vk_pipeline pipeline_rms_norm_mul_partials_f32; |
| 628 | vk_pipeline pipeline_rms_norm_mul_rope_f32_f32; |
| 629 | vk_pipeline pipeline_rms_norm_mul_rope_f32_f16; |
| 630 | vk_pipeline pipeline_rms_norm_back_f32; |
| 631 | vk_pipeline pipeline_l2_norm_f32; |
| 632 | |
| 633 | // [src/dst 0=fp32,1=fp16] |
| 634 | vk_pipeline pipeline_exp[2]; |
| 635 | vk_pipeline pipeline_gelu[2]; |
| 636 | vk_pipeline pipeline_gelu_erf[2]; |
| 637 | vk_pipeline pipeline_gelu_quick[2]; |
| 638 | vk_pipeline pipeline_silu[2]; |
| 639 | vk_pipeline pipeline_relu[2]; |
| 640 | vk_pipeline pipeline_tanh[2]; |
| 641 | vk_pipeline pipeline_sigmoid[2]; |
| 642 | vk_pipeline pipeline_hardsigmoid[2]; |
| 643 | vk_pipeline pipeline_hardswish[2]; |
| 644 | |
| 645 | vk_pipeline pipeline_geglu[2]; |
| 646 | vk_pipeline pipeline_reglu[2]; |
| 647 | vk_pipeline pipeline_swiglu[2]; |
| 648 | vk_pipeline pipeline_swiglu_oai[2]; |
| 649 | vk_pipeline pipeline_geglu_erf[2]; |
| 650 | vk_pipeline pipeline_geglu_quick[2]; |
| 651 | |
| 652 | vk_pipeline pipeline_leaky_relu_f32; |
| 653 | vk_pipeline pipeline_silu_back_f32; |
| 654 | vk_pipeline pipeline_diag_mask_inf_f32; |
| 655 | vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; |
| 656 | vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; |
| 657 | vk_pipeline pipeline_soft_max_back_f32; |
| 658 | vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; |
| 659 | vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; |
| 660 | vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; |
| 661 | vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; |
| 662 | vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; |
| 663 | vk_pipeline pipeline_sum_rows_f32; |
| 664 | vk_pipeline pipeline_argmax_f32; |
| 665 | vk_pipeline pipeline_count_equal_i32; |
| 666 | vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; |
| 667 | vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; |
| 668 | vk_pipeline pipeline_timestep_embedding_f32; |
| 669 | vk_pipeline pipeline_conv_transpose_1d_f32; |
| 670 | vk_pipeline pipeline_pool2d_f32; |
| 671 | vk_pipeline pipeline_rwkv_wkv6_f32; |
| 672 | vk_pipeline pipeline_rwkv_wkv7_f32; |
| 673 | vk_pipeline pipeline_ssm_scan_f32_d128; |
| 674 | vk_pipeline pipeline_ssm_scan_f32_d256; |
| 675 | vk_pipeline pipeline_ssm_conv_f32; |
| 676 | vk_pipeline pipeline_opt_step_adamw_f32; |
| 677 | vk_pipeline pipeline_opt_step_sgd_f32; |
| 678 | vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; |
| 679 | vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; |
| 680 | vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; |
| 681 | vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; |
| 682 | vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; |
| 683 | vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; |
| 684 | |
| 685 | std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; |
| 686 | |
| 687 | vk_pipeline pipeline_flash_attn_split_k_reduce; |
| 688 | |
| 689 | vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT]; |
| 690 | |
| 691 | std::vector<vk_pipeline_ref> all_pipelines; |
| 692 | |
| 693 | std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory; |
| 694 | |
| 695 | vk::Fence fence; |
| 696 | vk_buffer sync_staging; |
| 697 | |
| 698 | ggml_backend_buffer_type buffer_type; |
| 699 | |
| 700 | bool disable_fusion; |
| 701 | bool disable_host_visible_vidmem; |
| 702 | bool allow_sysmem_fallback; |
| 703 | bool disable_graph_optimize; |
| 704 | |
| 705 | #ifdef GGML_VULKAN_MEMORY_DEBUG |
| 706 | std::unique_ptr<vk_memory_logger> memory_logger; |
| 707 | #endif |
| 708 | |
| 709 | // for GGML_VK_PERF_LOGGER |
| 710 | std::unique_ptr<vk_perf_logger> perf_logger; |
| 711 | vk::QueryPool query_pool; |
| 712 | int32_t num_queries; |
| 713 | |
| 714 | ~vk_device_struct() { |
| 715 | VK_LOG_DEBUG("destroy device " << name); |
| 716 | |
| 717 | device.destroyFence(fence); |
| 718 | |
| 719 | ggml_vk_destroy_buffer(sync_staging); |
| 720 | |
| 721 | compute_queue.cmd_pool.destroy(device); |
| 722 | transfer_queue.cmd_pool.destroy(device); |
| 723 | |
| 724 | for (auto& pipeline : all_pipelines) { |
| 725 | if (pipeline.expired()) { |
| 726 | continue; |
| 727 | } |
| 728 | |
| 729 | vk_pipeline pl = pipeline.lock(); |
| 730 | ggml_vk_destroy_pipeline(device, pl); |
| 731 | } |
| 732 | all_pipelines.clear(); |
| 733 | |
| 734 | device.destroyDescriptorSetLayout(dsl); |
| 735 | |
| 736 | device.destroy(); |
| 737 | } |
| 738 | }; |
| 739 | |
| 740 | void vk_command_pool::init(vk_device& device, vk_queue *q_) { |
| 741 | cmd_buffer_idx = 0; |
| 742 | q = q_; |
| 743 | |
| 744 | vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); |
| 745 | pool = device->device.createCommandPool(command_pool_create_info); |
| 746 | } |
| 747 | |
| 748 | void vk_command_pool::destroy(vk::Device& device) { |
| 749 | device.destroyCommandPool(pool); |
| 750 | pool = nullptr; |
| 751 | cmd_buffers.clear(); |
| 752 | } |
| 753 | |
| 754 | struct vk_buffer_struct { |
| 755 | vk::Buffer buffer = VK_NULL_HANDLE; |
| 756 | vk::DeviceMemory device_memory = VK_NULL_HANDLE; |
| 757 | vk::MemoryPropertyFlags memory_property_flags; |
| 758 | void * ptr; |
| 759 | size_t size = 0; |
| 760 | vk::DeviceAddress bda_addr {}; |
| 761 | |
| 762 | vk_device device; |
| 763 | |
| 764 | ~vk_buffer_struct() { |
| 765 | if (size == 0) { |
| 766 | return; |
| 767 | } |
| 768 | VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")" ); |
| 769 | |
| 770 | device->device.freeMemory(device_memory); |
| 771 | device->device.destroyBuffer(buffer); |
| 772 | } |
| 773 | }; |
| 774 | |
| 775 | struct vk_subbuffer { |
| 776 | vk_buffer buffer; |
| 777 | uint64_t offset; |
| 778 | uint64_t size; |
| 779 | |
| 780 | operator vk::DescriptorBufferInfo() const { |
| 781 | return { buffer->buffer, offset, size }; |
| 782 | } |
| 783 | }; |
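|  | // The conversion operator above lets a vk_subbuffer be passed directly wherever a |
|  | // vk::DescriptorBufferInfo is expected, e.g. when filling descriptor writes. |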
| 784 | |
| 785 | struct vk_semaphore { |
| 786 | vk::Semaphore s; |
| 787 | uint64_t value; |
| 788 | }; |
| 789 | |
| 790 | struct vk_submission { |
| 791 | vk::CommandBuffer buffer; |
| 792 | std::vector<vk_semaphore> wait_semaphores; |
| 793 | std::vector<vk_semaphore> signal_semaphores; |
| 794 | }; |
| 795 | |
| 796 | typedef std::vector<vk_submission> vk_sequence; |
| 797 | |
| 798 | struct vk_mat_mat_push_constants { |
| 799 | uint32_t M; uint32_t N; uint32_t K; |
| 800 | uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; |
| 801 | uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; |
| 802 | uint32_t k_split; |
| 803 | uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; |
| 804 | uint32_t padded_N; |
| 805 | }; |
| 806 | struct vk_mat_vec_push_constants { |
| 807 | uint32_t ncols; |
| 808 | uint32_t stride_a; |
| 809 | uint32_t stride_b; |
| 810 | uint32_t stride_d; |
| 811 | uint32_t batch_stride_a; |
| 812 | uint32_t batch_stride_b; |
| 813 | uint32_t batch_stride_d; |
| 814 | uint32_t enable_bias; |
| 815 | uint32_t ne02; |
| 816 | uint32_t ne12; |
| 817 | uint32_t broadcast2; |
| 818 | uint32_t broadcast3; |
| 819 | }; |
| 820 | |
| 821 | struct vk_mat_mat_id_push_constants { |
| 822 | uint32_t M; uint32_t N; uint32_t K; |
| 823 | uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; |
| 824 | uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; |
| 825 | uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; |
| 826 | uint32_t padded_N; |
| 827 | }; |
| 828 | struct vk_mat_vec_id_push_constants { |
| 829 | uint32_t ncols; |
| 830 | uint32_t stride_a; |
| 831 | uint32_t stride_b; |
| 832 | uint32_t stride_d; |
| 833 | uint32_t batch_stride_a; |
| 834 | uint32_t batch_stride_b; |
| 835 | uint32_t batch_stride_d; |
| 836 | uint32_t enable_bias; |
| 837 | uint32_t nei0; |
| 838 | uint32_t ne11; |
| 839 | }; |
| 840 | |
| 841 | struct vk_flash_attn_push_constants { |
| 842 | uint32_t N; |
| 843 | uint32_t KV; |
| 844 | |
| 845 | uint32_t ne1; |
| 846 | uint32_t ne2; |
| 847 | uint32_t ne3; |
| 848 | |
| 849 | uint32_t neq2; |
| 850 | uint32_t neq3; |
| 851 | uint32_t nek2; |
| 852 | uint32_t nek3; |
| 853 | uint32_t nev2; |
| 854 | uint32_t nev3; |
| 855 | uint32_t nem1; |
| 856 | uint32_t nem2; |
| 857 | uint32_t nem3; |
| 858 | |
| 859 | uint32_t nb01; |
| 860 | uint32_t nb02; |
| 861 | uint32_t nb03; |
| 862 | uint32_t nb11; |
| 863 | uint32_t nb12; |
| 864 | uint32_t nb13; |
| 865 | uint32_t nb21; |
| 866 | uint32_t nb22; |
| 867 | uint32_t nb23; |
| 868 | |
| 869 | float scale; |
| 870 | float max_bias; |
| 871 | float logit_softcap; |
| 872 | |
| 873 | uint32_t mask_n_head_log2; |
| 874 | float m0; |
| 875 | float m1; |
| 876 | |
| 877 | uint32_t gqa_ratio; |
| 878 | uint32_t split_kv; |
| 879 | uint32_t k_num; |
| 880 | }; |
| 881 | static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); |
| 882 | |
| 883 | struct vk_op_push_constants { |
| 884 | uint32_t KX; |
| 885 | uint32_t KY; |
| 886 | float param1; |
| 887 | float param2; |
| 888 | }; |
| 889 | |
| 890 | struct vk_op_glu_push_constants { |
| 891 | uint32_t N; |
| 892 | uint32_t ne00; |
| 893 | uint32_t ne20; |
| 894 | uint32_t mode; // 0: default, 1: swapped, 2: split |
| 895 | float alpha; // for swiglu_oai |
| 896 | float limit; |
| 897 | }; |
| 898 | |
| 899 | struct vk_op_unary_push_constants { |
| 900 | uint32_t ne; |
| 901 | uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; |
| 902 | uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; |
| 903 | uint32_t misalign_offsets; |
| 904 | float param1; float param2; |
| 905 | uint32_t ne0_012mp; uint32_t ne0_012L; |
| 906 | uint32_t ne0_01mp; uint32_t ne0_01L; |
| 907 | uint32_t ne0_0mp; uint32_t ne0_0L; |
| 908 | uint32_t ne1_012mp; uint32_t ne1_012L; |
| 909 | uint32_t ne1_01mp; uint32_t ne1_01L; |
| 910 | uint32_t ne1_0mp; uint32_t ne1_0L; |
| 911 | }; |
| 912 | static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); |
| 913 | |
| 914 | static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) { |
| 915 | GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst))); |
| 916 | ne = ne != 0 ? ne : ggml_nelements(dst); |
| 917 | GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max()); |
| 918 | |
| 919 | vk_op_unary_push_constants p{}; |
| 920 | p.ne = (uint32_t)ne; |
| 921 | |
| 922 | size_t src0_tsize = ggml_type_size(src0->type); |
| 923 | p.ne00 = (uint32_t)src0->ne[0]; |
| 924 | p.ne01 = (uint32_t)src0->ne[1]; |
| 925 | p.ne02 = (uint32_t)src0->ne[2]; |
| 926 | p.ne03 = (uint32_t)src0->ne[3]; |
| 927 | p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); |
| 928 | p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); |
| 929 | p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); |
| 930 | p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); |
| 931 | |
| 932 | size_t dst_tsize = ggml_type_size(dst->type); |
| 933 | p.ne10 = (uint32_t)dst->ne[0]; |
| 934 | p.ne11 = (uint32_t)dst->ne[1]; |
| 935 | p.ne12 = (uint32_t)dst->ne[2]; |
| 936 | p.ne13 = (uint32_t)dst->ne[3]; |
| 937 | p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); |
| 938 | p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); |
| 939 | p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); |
| 940 | p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); |
| 941 | |
| 942 | return p; // offsets are initialized later in ggml_vk_op |
| 943 | } |
| 944 | |
| 945 | struct vk_op_pad_push_constants { |
| 946 | uint32_t ne; |
| 947 | uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; |
| 948 | uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; |
| 949 | uint32_t misalign_offsets; |
| 950 | |
| 951 | uint32_t lp0; uint32_t rp0; |
| 952 | uint32_t lp1; uint32_t rp1; |
| 953 | uint32_t lp2; uint32_t rp2; |
| 954 | uint32_t lp3; uint32_t rp3; |
| 955 | }; |
| 956 | |
| 957 | static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) { |
| 958 | int64_t ne = ggml_nelements(dst); |
| 959 | GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max()); |
| 960 | |
| 961 | vk_op_pad_push_constants p{}; |
| 962 | p.ne = (uint32_t)ne; |
| 963 | |
| 964 | size_t src0_tsize = ggml_type_size(src0->type); |
| 965 | p.ne00 = (uint32_t)src0->ne[0]; |
| 966 | p.ne01 = (uint32_t)src0->ne[1]; |
| 967 | p.ne02 = (uint32_t)src0->ne[2]; |
| 968 | p.ne03 = (uint32_t)src0->ne[3]; |
| 969 | p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); |
| 970 | p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); |
| 971 | p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); |
| 972 | p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); |
| 973 | |
| 974 | size_t dst_tsize = ggml_type_size(dst->type); |
| 975 | p.ne10 = (uint32_t)dst->ne[0]; |
| 976 | p.ne11 = (uint32_t)dst->ne[1]; |
| 977 | p.ne12 = (uint32_t)dst->ne[2]; |
| 978 | p.ne13 = (uint32_t)dst->ne[3]; |
| 979 | p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); |
| 980 | p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); |
| 981 | p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); |
| 982 | p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); |
| 983 | |
| 984 | p.lp0 = dst->op_params[0]; |
| 985 | p.rp0 = dst->op_params[1]; |
| 986 | p.lp1 = dst->op_params[2]; |
| 987 | p.rp1 = dst->op_params[3]; |
| 988 | p.lp2 = dst->op_params[4]; |
| 989 | p.rp2 = dst->op_params[5]; |
| 990 | p.lp3 = dst->op_params[6]; |
| 991 | p.rp3 = dst->op_params[7]; |
| 992 | |
| 993 | return p; // fastdiv values and offsets are initialized later in ggml_vk_op |
| 994 | } |
| 995 | |
| 996 | // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. |
| 997 | // Precompute mp (m' in the paper) and L such that division |
| 998 | // can be computed using a multiply (high 32b of 64b result) |
| 999 | // and a shift: |
| 1000 | // |
| 1001 | // n/d = (mulhi(n, mp) + n) >> L; |
| 1002 | static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) |
| 1003 | { |
| 1004 | // compute L = ceil(log2(d)); |
| 1005 | L = 0; |
| 1006 | while (L < 32 && (uint32_t{1} << L) < d) { |
| 1007 | L++; |
| 1008 | } |
| 1009 | |
| 1010 | mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); |
| 1011 | } |
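|  | // Worked example (illustrative): for d = 6 the loop gives L = 3 and |
|  | // mp = (2^32 * (8 - 6)) / 6 + 1 = 1431655766. Then for n = 25: |
|  | // mulhi(25, mp) = (25 * 1431655766) >> 32 = 8, and (8 + 25) >> 3 = 4 == 25 / 6. |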
| 1012 | |
| 1013 | template <typename T> void init_pushconst_fastdiv(T &p) { |
| 1014 | GGML_UNUSED(p); |
| 1015 | static_assert(!std::is_const<T>::value, "unexpected type"); |
| 1016 | } |
| 1017 | |
| 1018 | template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { |
| 1019 | // Compute magic values to divide by these six numbers. |
| 1020 | init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); |
| 1021 | init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); |
| 1022 | init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); |
| 1023 | init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); |
| 1024 | init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); |
| 1025 | init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); |
| 1026 | } |
| 1027 | |
| 1028 | struct vk_op_binary_push_constants { |
| 1029 | uint32_t ne; |
| 1030 | uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; |
| 1031 | uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; |
| 1032 | uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; |
| 1033 | uint32_t misalign_offsets; |
| 1034 | float param1; float param2; int32_t param3; |
| 1035 | }; |
| 1036 | |
| 1037 | struct vk_op_multi_add_push_constants { |
| 1038 | // shape for dst |
| 1039 | uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; |
| 1040 | |
| 1041 | // strides for srcs+dst |
| 1042 | uint32_t nb[MAX_PARAMETER_COUNT][4]; |
| 1043 | |
| 1044 | uint32_t rms_partials; |
| 1045 | }; |
| 1046 | // update multi_add.comp if this changes |
| 1047 | static_assert(MAX_PARAMETER_COUNT == 12); |
| 1048 | static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); |
| 1049 | |
| 1050 | struct vk_op_topk_moe_push_constants { |
| 1051 | uint32_t n_rows; |
| 1052 | uint32_t n_expert_used; |
| 1053 | float clamp_min; |
| 1054 | float clamp_max; |
| 1055 | }; |
| 1056 | |
| 1057 | struct vk_op_add_id_push_constants { |
| 1058 | uint32_t ne0; |
| 1059 | uint32_t ne1; |
| 1060 | uint32_t s01; |
| 1061 | uint32_t s02; |
| 1062 | uint32_t s11; |
| 1063 | uint32_t s21; |
| 1064 | }; |
| 1065 | |
| 1066 | struct vk_op_diag_mask_push_constants { |
| 1067 | uint32_t ncols; |
| 1068 | uint32_t rows_per_channel; |
| 1069 | int32_t n_past; |
| 1070 | }; |
| 1071 | |
| 1072 | struct vk_op_rope_push_constants { |
| 1073 | uint32_t rope_mode; |
| 1074 | uint32_t ncols; |
| 1075 | uint32_t n_dims; |
| 1076 | float freq_scale; |
| 1077 | uint32_t p_delta_rows; |
| 1078 | float freq_base; |
| 1079 | float ext_factor; |
| 1080 | float attn_factor; |
| 1081 | float corr_dims[2]; |
| 1082 | float theta_scale; |
| 1083 | uint32_t has_ff; |
| 1084 | uint32_t ne02; |
| 1085 | uint32_t s1; |
| 1086 | uint32_t s2; |
| 1087 | int32_t sections[4]; |
| 1088 | uint32_t is_imrope; |
| 1089 | uint32_t is_back; |
| 1090 | uint32_t set_rows_stride; |
| 1091 | }; |
| 1092 | |
| 1093 | // For fused rms_norm+mul+rope(+view+set_rows) |
| 1094 | struct vk_op_rms_norm_mul_rope_push_constants { |
| 1095 | vk_op_binary_push_constants bin; |
| 1096 | vk_op_rope_push_constants rope; |
| 1097 | }; |
| 1098 | |
| 1099 | struct vk_op_soft_max_push_constants { |
| 1100 | uint32_t KX; |
| 1101 | uint32_t KY; |
| 1102 | uint32_t ne00; |
| 1103 | uint32_t ne01; |
| 1104 | uint32_t ne02; |
| 1105 | uint32_t ne12; |
| 1106 | uint32_t ne13; |
| 1107 | uint32_t nb11; |
| 1108 | uint32_t nb12; |
| 1109 | uint32_t nb13; |
| 1110 | float scale; |
| 1111 | float max_bias; |
| 1112 | float m0; |
| 1113 | float m1; |
| 1114 | uint32_t n_head_log2; |
| 1115 | uint32_t nrows_x; |
| 1116 | uint32_t has_sinks; |
| 1117 | }; |
| 1118 | |
| 1119 | struct vk_op_argsort_push_constants { |
| 1120 | uint32_t ncols; |
| 1121 | uint32_t nrows; |
| 1122 | int32_t order; |
| 1123 | }; |
| 1124 | |
| 1125 | struct vk_op_im2col_push_constants { |
| 1126 | uint64_t dst_addr; |
| 1127 | uint32_t batch_offset; uint32_t offset_delta; |
| 1128 | uint32_t IC; |
| 1129 | uint32_t IW; uint32_t IH; |
| 1130 | uint32_t OW; uint32_t OH; |
| 1131 | uint32_t KW; uint32_t KH; |
| 1132 | uint32_t pelements; |
| 1133 | uint32_t CHW; |
| 1134 | int32_t s0; int32_t s1; |
| 1135 | int32_t p0; int32_t p1; |
| 1136 | int32_t d0; int32_t d1; |
| 1137 | }; |
| 1138 | |
| 1139 | struct vk_op_im2col_3d_push_constants { |
| 1140 | uint64_t dst_addr; |
| 1141 | uint32_t nb10; |
| 1142 | uint32_t nb11; |
| 1143 | uint32_t nb12; |
| 1144 | uint32_t nb13; |
| 1145 | uint32_t s0; |
| 1146 | uint32_t s1; |
| 1147 | uint32_t s2; |
| 1148 | uint32_t p0; |
| 1149 | uint32_t p1; |
| 1150 | uint32_t p2; |
| 1151 | uint32_t d0; |
| 1152 | uint32_t d1; |
| 1153 | uint32_t d2; |
| 1154 | uint32_t IW; |
| 1155 | uint32_t IH; |
| 1156 | uint32_t ID; |
| 1157 | uint32_t IC; |
| 1158 | uint32_t KW; |
| 1159 | uint32_t OH; |
| 1160 | uint32_t KD_KH_KW; |
| 1161 | uint32_t KH_KW; |
| 1162 | uint32_t IC_KD_KH_KW; |
| 1163 | uint32_t N_OD_OH; |
| 1164 | uint32_t OD_OH; |
| 1165 | uint32_t OD_OH_OW_IC_KD_KH_KW; |
| 1166 | uint32_t OH_OW_IC_KD_KH_KW; |
| 1167 | uint32_t OW_IC_KD_KH_KW; |
| 1168 | uint32_t misalign_offsets; |
| 1169 | }; |
| 1170 | |
| 1171 | struct vk_op_timestep_embedding_push_constants { |
| 1172 | uint32_t nb1; |
| 1173 | uint32_t dim; |
| 1174 | uint32_t max_period; |
| 1175 | }; |
| 1176 | |
| 1177 | struct vk_op_conv_transpose_1d_push_constants { |
| 1178 | uint32_t Cout; |
| 1179 | uint32_t Cin; |
| 1180 | uint32_t K; |
| 1181 | uint32_t L; |
| 1182 | uint32_t KL; |
| 1183 | |
| 1184 | uint32_t nb01; |
| 1185 | uint32_t nb02; |
| 1186 | uint32_t nb11; |
| 1187 | uint32_t nb1; |
| 1188 | |
| 1189 | int32_t s0; |
| 1190 | }; |
| 1191 | |
| 1192 | struct vk_op_pool2d_push_constants { |
| 1193 | uint32_t IW; uint32_t IH; |
| 1194 | uint32_t OW; uint32_t OH; |
| 1195 | uint32_t OC; |
| 1196 | uint32_t pelements; |
| 1197 | uint32_t op; |
| 1198 | int32_t k0; int32_t k1; |
| 1199 | int32_t s0; int32_t s1; |
| 1200 | int32_t p0; int32_t p1; |
| 1201 | }; |
| 1202 | |
| 1203 | struct vk_op_rwkv_wkv6_push_constants { |
| 1204 | uint32_t B; |
| 1205 | uint32_t T; |
| 1206 | uint32_t C; |
| 1207 | uint32_t H; |
| 1208 | }; |
| 1209 | |
| 1210 | struct vk_op_rwkv_wkv7_push_constants { |
| 1211 | uint32_t B; |
| 1212 | uint32_t T; |
| 1213 | uint32_t C; |
| 1214 | uint32_t H; |
| 1215 | }; |
| 1216 | struct vk_op_ssm_scan_push_constants { |
| 1217 | uint32_t nb02, nb03, nb12, nb13; |
| 1218 | uint32_t nb21, nb22, nb31; |
| 1219 | uint32_t nb42, nb43, nb52, nb53; |
| 1220 | uint32_t s_off; |
| 1221 | uint32_t n_head, d_head, n_group, n_tok; |
| 1222 | }; |
| 1223 | struct vk_op_ssm_conv_push_constants { |
| 1224 | uint32_t nb01, nb02; |
| 1225 | uint32_t nb11; |
| 1226 | uint32_t dst_nb0, dst_nb1, dst_nb2; |
| 1227 | uint32_t nc, ncs, nr, n_t, n_s; |
| 1228 | }; |
| 1229 | |
| 1230 | struct vk_op_conv2d_push_constants { |
| 1231 | uint32_t Cout; |
| 1232 | uint32_t Cin; |
| 1233 | uint32_t N; |
| 1234 | |
| 1235 | uint32_t KW; |
| 1236 | uint32_t KH; |
| 1237 | uint32_t W; |
| 1238 | uint32_t H; |
| 1239 | uint32_t OW; |
| 1240 | uint32_t OH; |
| 1241 | |
| 1242 | uint32_t s0; |
| 1243 | uint32_t s1; |
| 1244 | uint32_t p0; |
| 1245 | uint32_t p1; |
| 1246 | uint32_t d0; |
| 1247 | uint32_t d1; |
| 1248 | |
| 1249 | uint32_t nb01; |
| 1250 | uint32_t nb02; |
| 1251 | uint32_t nb03; |
| 1252 | |
| 1253 | uint32_t nb11; |
| 1254 | uint32_t nb12; |
| 1255 | uint32_t nb13; |
| 1256 | |
| 1257 | uint32_t nb1; |
| 1258 | uint32_t nb2; |
| 1259 | uint32_t nb3; |
| 1260 | |
| 1261 | // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH |
| 1262 | uint32_t KWmp; uint32_t KWL; |
| 1263 | uint32_t KWKHmp; uint32_t KWKHL; |
| 1264 | uint32_t OWmp; uint32_t OWL; |
| 1265 | uint32_t OWOHmp; uint32_t OWOHL; |
| 1266 | }; |
| 1267 | |
| 1268 | template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { |
| 1269 | // Compute magic values to divide by KW, KW*KH, OW, OW*OH |
| 1270 | init_fastdiv_values(p.KW, p.KWmp, p.KWL); |
| 1271 | init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); |
| 1272 | init_fastdiv_values(p.OW, p.OWmp, p.OWL); |
| 1273 | init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); |
| 1274 | } |
| 1275 | |
| 1276 | struct vk_op_conv_transpose_2d_push_constants { |
| 1277 | uint32_t Cout; |
| 1278 | uint32_t Cin; |
| 1279 | uint32_t N; |
| 1280 | |
| 1281 | uint32_t KW; |
| 1282 | uint32_t KH; |
| 1283 | uint32_t W; |
| 1284 | uint32_t H; |
| 1285 | uint32_t OW; |
| 1286 | uint32_t OH; |
| 1287 | |
| 1288 | uint32_t s0; |
| 1289 | uint32_t s1; |
| 1290 | uint32_t p0; |
| 1291 | uint32_t p1; |
| 1292 | uint32_t d0; |
| 1293 | uint32_t d1; |
| 1294 | |
| 1295 | uint32_t nb01; |
| 1296 | uint32_t nb02; |
| 1297 | uint32_t nb03; |
| 1298 | |
| 1299 | uint32_t nb11; |
| 1300 | uint32_t nb12; |
| 1301 | uint32_t nb13; |
| 1302 | |
| 1303 | uint32_t nb1; |
| 1304 | uint32_t nb2; |
| 1305 | uint32_t nb3; |
| 1306 | |
| 1307 | // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1 |
| 1308 | uint32_t KWmp; uint32_t KWL; |
| 1309 | uint32_t KWKHmp; uint32_t KWKHL; |
| 1310 | uint32_t OWmp; uint32_t OWL; |
| 1311 | uint32_t OWOHmp; uint32_t OWOHL; |
| 1312 | uint32_t s0mp; uint32_t s0L; |
| 1313 | uint32_t s1mp; uint32_t s1L; |
| 1314 | }; |
| 1315 | |
| 1316 | template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) { |
| 1317 | // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1 |
| 1318 | init_fastdiv_values(p.KW, p.KWmp, p.KWL); |
| 1319 | init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); |
| 1320 | init_fastdiv_values(p.OW, p.OWmp, p.OWL); |
| 1321 | init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); |
| 1322 | init_fastdiv_values(p.s0, p.s0mp, p.s0L); |
| 1323 | init_fastdiv_values(p.s1, p.s1mp, p.s1L); |
| 1324 | } |
| 1325 | |
| 1326 | struct vk_op_conv2d_dw_push_constants { |
| 1327 | uint32_t ne; |
| 1328 | uint32_t batches; |
| 1329 | uint32_t channels; |
| 1330 | uint32_t dst_w; |
| 1331 | uint32_t dst_h; |
| 1332 | uint32_t src_w; |
| 1333 | uint32_t src_h; |
| 1334 | uint32_t knl_w; |
| 1335 | uint32_t knl_h; |
| 1336 | int32_t stride_x; |
| 1337 | int32_t stride_y; |
| 1338 | int32_t pad_x; |
| 1339 | int32_t pad_y; |
| 1340 | int32_t dilation_x; |
| 1341 | int32_t dilation_y; |
| 1342 | }; |
| 1343 | |
| 1344 | struct vk_op_upscale_push_constants { |
| 1345 | uint32_t ne; uint32_t a_offset; uint32_t d_offset; |
| 1346 | uint32_t ne00; uint32_t ne01; |
| 1347 | uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; |
| 1348 | uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; |
| 1349 | float sf0; float sf1; float sf2; float sf3; |
| 1350 | float pixel_offset; |
| 1351 | }; |
| 1352 | |
| 1353 | struct vk_op_sum_rows_push_constants |
| 1354 | { |
| 1355 | uint32_t n_cols; |
| 1356 | uint32_t ne01, ne02; |
| 1357 | uint32_t nb01, nb02, nb03; |
| 1358 | uint32_t nb11, nb12, nb13; |
| 1359 | float weight; |
| 1360 | uint32_t misalign_offsets; |
| 1361 | uint32_t ne0_12mp, ne0_12L; |
| 1362 | uint32_t ne0_1mp, ne0_1L; |
| 1363 | }; |
| 1364 | |
| 1365 | static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) { |
| 1366 | uint32_t type_size = (uint32_t)ggml_type_size(src->type); |
| 1367 | vk_op_sum_rows_push_constants p = {}; |
| 1368 | p.n_cols = (uint32_t)n_cols; |
| 1369 | p.ne01 = (uint32_t)src->ne[1]; |
| 1370 | p.ne02 = (uint32_t)src->ne[2]; |
| 1371 | p.nb01 = (uint32_t)src->nb[1] / type_size; |
| 1372 | p.nb02 = (uint32_t)src->nb[2] / type_size; |
| 1373 | p.nb03 = (uint32_t)src->nb[3] / type_size; |
| 1374 | p.nb11 = (uint32_t)dst->nb[1] / type_size; |
| 1375 | p.nb12 = (uint32_t)dst->nb[2] / type_size; |
| 1376 | p.nb13 = (uint32_t)dst->nb[3] / type_size; |
| 1377 | p.weight = 1.0f; |
| 1378 | return p; |
| 1379 | } |
| 1380 | |
| 1381 | template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { |
| 1382 | init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L); |
| 1383 | init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); |
| 1384 | } |
| 1385 | |
| 1386 | // Allow pre-recording command buffers |
| 1387 | struct vk_staging_memcpy { |
| 1388 | vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} |
| 1389 | |
| 1390 | void * dst; |
| 1391 | const void * src; |
| 1392 | size_t n; |
| 1393 | }; |
| 1394 | |
| 1395 | struct vk_staging_memset { |
| 1396 | vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {} |
| 1397 | |
| 1398 | void * dst; |
| 1399 | uint32_t val; |
| 1400 | size_t n; |
| 1401 | }; |
| 1402 | |
| 1403 | struct vk_context_struct { |
| 1404 | vk_submission * s; |
| 1405 | std::vector<vk_sequence> seqs; |
| 1406 | |
| 1407 | int exit_tensor_idx; |
| 1408 | |
| 1409 | std::vector<vk_staging_memcpy> in_memcpys; |
| 1410 | std::vector<vk_staging_memcpy> out_memcpys; |
| 1411 | std::vector<vk_staging_memset> memsets; |
| 1412 | |
| 1413 | vk_command_pool * p {}; |
| 1414 | }; |
| 1415 | typedef std::shared_ptr<vk_context_struct> vk_context; |
| 1416 | typedef std::weak_ptr<vk_context_struct> vk_context_ref; |
| 1417 | |
| 1418 | struct ggml_vk_garbage_collector { |
| 1419 | std::vector<vk_semaphore> tl_semaphores; |
| 1420 | std::vector<vk_semaphore> semaphores; |
| 1421 | std::vector<vk::Event> events; |
| 1422 | std::vector<vk_context> contexts; |
| 1423 | }; |
| 1424 | |
| 1425 | static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx); |
| 1426 | static void ggml_vk_load_shaders(vk_device& device); |
| 1427 | static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx); |
| 1428 | |
| 1429 | #if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) |
| 1430 | #define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl |
| 1431 | |
| 1432 | static std::string format_size(size_t size) { |
| 1433 | const size_t kib = 1024; |
| 1434 | const size_t mib = kib * 1024; |
| 1435 | const size_t gib = mib * 1024; |
| 1436 | |
| 1437 | std::ostringstream oss; |
| 1438 | oss << std::fixed << std::setprecision(2); |
| 1439 | |
| 1440 | if (size >= gib) { |
| 1441 | oss << static_cast<double>(size) / gib << " GiB";
| 1442 | } else if (size >= mib) {
| 1443 | oss << static_cast<double>(size) / mib << " MiB";
| 1444 | } else if (size >= kib) {
| 1445 | oss << static_cast<double>(size) / kib << " KiB";
| 1446 | } else {
| 1447 | oss << size << " B";
| 1448 | } |
| 1449 | |
| 1450 | return oss.str(); |
| 1451 | } |
| 1452 | |
| 1453 | class vk_memory_logger { |
| 1454 | public: |
| 1455 | vk_memory_logger(): total_device(0), total_host(0) {} |
| 1456 | void log_allocation(vk_buffer_ref buf_ref, size_t size); |
| 1457 | void log_deallocation(vk_buffer_ref buf_ref); |
| 1458 | |
| 1459 | private: |
| 1460 | std::map<vk::Buffer, size_t> allocations; // Track allocations |
| 1461 | size_t total_device; |
| 1462 | size_t total_host; |
| 1463 | }; |
| 1464 | #else |
| 1465 | #define VK_LOG_MEMORY(msg) ((void) 0) |
| 1466 | #endif // GGML_VULKAN_MEMORY_DEBUG |
| 1467 | |
| 1468 | class vk_perf_logger { |
| 1469 | public: |
| 1470 | void print_timings() { |
| 1471 | if (timings.empty()) { |
| 1472 | return; |
| 1473 | } |
| 1474 | uint64_t total_all_op_times = 0; |
| 1475 | std::cerr << "----------------\nVulkan Timings:" << std::endl; |
| 1476 | for (const auto & t : timings) { |
| 1477 | uint64_t total_op_times = 0; |
| 1478 | for (const auto & time : t.second) { |
| 1479 | total_op_times += time; |
| 1480 | } |
| 1481 | std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
| 1482 | << " us";
| 1483 | 
| 1484 | // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
| 1485 | auto it = flops.find(t.first);
| 1486 | if (it != flops.end() && (it->second).size() == t.second.size()) { |
| 1487 | uint64_t total_op_flops = 0; |
| 1488 | for (const auto & elem : it->second) { |
| 1489 | total_op_flops += elem; |
| 1490 | } |
| 1491 | std::cerr << " (" |
| 1492 | << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) / |
| 1493 | (double(total_op_times) / (1000.0 * 1000.0 * 1000.0)) |
| 1494 | << " GFLOPS/s)" ; |
| 1495 | } |
| 1496 | |
| 1497 | total_all_op_times += total_op_times; |
| 1498 | |
| 1499 | std::cerr << std::endl; |
| 1500 | } |
| 1501 | |
| 1502 | if (timings.size() > 0) { |
| 1503 | std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl; |
| 1504 | } |
| 1505 | |
| 1506 | timings.clear(); |
| 1507 | flops.clear(); |
| 1508 | } |
| 1509 | |
| 1510 | void log_timing(const ggml_tensor * node, uint64_t time) { |
| 1511 | if (node->op == GGML_OP_UNARY) { |
| 1512 | timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
| 1513 | return; |
| 1514 | } |
| 1515 | if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { |
| 1516 | const uint64_t m = node->src[0]->ne[1]; |
| 1517 | const uint64_t n = node->ne[1]; |
| 1518 | const uint64_t k = node->src[1]->ne[0]; |
| 1519 | const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; |
| 1520 | std::string name = ggml_op_name(node->op);
| 1521 | if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
| 1522 | (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
| 1523 | name += "_VEC";
| 1524 | }
| 1525 | name += " ";
| 1526 | name += ggml_type_name(node->src[0]->type);
| 1527 | name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
| 1528 | if (batch > 1) {
| 1529 | name += " batch=" + std::to_string(batch);
| 1530 | }
| 1531 | timings[name].push_back(time);
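|  | // FLOP count for the matmul: each of the m*n output elements takes k multiplies and k-1 adds
|  | // (2k-1 FLOPs), per batch element.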
| 1532 | flops[name].push_back(m * n * (k + (k - 1)) * batch);
| 1533 | return; |
| 1534 | } |
| 1535 | if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) { |
| 1536 | std::string name = ggml_op_name(node->op);
| 1537 | ggml_tensor * knl = node->src[0]; |
| 1538 | uint64_t OW = node->ne[0]; |
| 1539 | uint64_t OH = node->ne[1]; |
| 1540 | uint64_t N = node->ne[3]; |
| 1541 | uint64_t Cout = node->ne[2]; |
| 1542 | uint64_t KW = knl->ne[0]; |
| 1543 | uint64_t KH = knl->ne[1]; |
| 1544 | uint64_t Cin = node->src[1]->ne[2]; |
| 1545 | // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ |
| 1546 | uint64_t size_M = Cout; |
| 1547 | uint64_t size_K = Cin * KW * KH; |
| 1548 | uint64_t size_N = N * OW * OH; |
| 1549 | uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); |
| 1550 | name += " M=Cout=" + std::to_string(val: size_M) + ", K=Cin*KW*KH=" + std::to_string(val: size_K) + |
| 1551 | ", N=N*OW*OH=" + std::to_string(val: size_N); |
| 1552 | flops[name].push_back(x: n_flops); |
| 1553 | timings[name].push_back(x: time); |
| 1554 | return; |
| 1555 | } |
| 1556 | if (node->op == GGML_OP_RMS_NORM) { |
| 1557 | std::string name = ggml_op_name(node->op);
| 1558 | name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
| 1559 | timings[name].push_back(time);
| 1560 | return; |
| 1561 | } |
| 1562 | timings[ggml_op_name(node->op)].push_back(time);
| 1563 | } |
| 1564 | private: |
| 1565 | std::map<std::string, std::vector<uint64_t>> timings; |
| 1566 | std::map<std::string, std::vector<uint64_t>> flops; |
| 1567 | }; |
| 1568 | |
| 1569 | struct ggml_backend_vk_context { |
| 1570 | std::string name; |
| 1571 | |
| 1572 | vk_device device; |
| 1573 | |
| 1574 | size_t semaphore_idx, event_idx; |
| 1575 | ggml_vk_garbage_collector gc; |
| 1576 | size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset; |
| 1577 | vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials; |
| 1578 | vk::Fence fence, almost_ready_fence; |
| 1579 | bool almost_ready_fence_pending {}; |
| 1580 | // Set before op_add and unset after op_rms_norm to indicate that the add should |
| 1581 | // write partial sums to accumulate the square of the vector components |
| 1582 | bool do_add_rms_partials_offset_calculation; |
| 1583 | bool do_add_rms_partials; |
| 1584 | |
| 1585 | uint64_t last_total_mul_mat_bytes {}; |
| 1586 | |
| 1587 | // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. |
| 1588 | vk_pipeline_struct * prealloc_y_last_pipeline_used {}; |
| 1589 | const ggml_tensor * prealloc_y_last_tensor_used {}; |
| 1590 | |
| 1591 | // Track which nodes have been used since the last sync, and whether they were written to |
| 1592 | std::vector<const ggml_tensor *> unsynced_nodes_written; |
| 1593 | std::vector<const ggml_tensor *> unsynced_nodes_read; |
| 1594 | // Track which prealloc buffers have pending reads that need to be synchronized. |
| 1595 | // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set), |
| 1596 | // and set to true after the buffer contents are consumed. |
| 1597 | bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; |
| 1598 | |
| 1599 | vk_context_ref compute_ctx; |
| 1600 | vk_context_ref transfer_ctx; |
| 1601 | |
| 1602 | std::vector<vk_context_ref> tensor_ctxs; |
| 1603 | |
| 1604 | std::vector<vk::DescriptorPool> descriptor_pools; |
| 1605 | std::vector<vk::DescriptorSet> descriptor_sets; |
| 1606 | uint32_t descriptor_set_idx {}; |
| 1607 | uint32_t pipeline_descriptor_set_requirements {}; |
| 1608 | |
| 1609 | vk_command_pool compute_cmd_pool; |
| 1610 | vk_command_pool transfer_cmd_pool; |
| 1611 | |
| 1612 | // number of additional consecutive nodes that are being fused with the |
| 1613 | // node currently being processed |
| 1614 | int num_additional_fused_ops {}; |
| 1615 | // Bitmask of which fused ops need to write an intermediate value to memory. |
| 1616 | // Bit 'i' means nodes[start_of_fusion + i] writes to memory. |
| 1617 | // If there's no fusion, bit 0 is still set. |
| 1618 | int fused_ops_write_mask {}; |
| 1619 | }; |
| 1620 | |
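|  | // Vulkan device buffers have no host address, so tensor->data holds a fake pointer: vk_ptr_base
|  | // plus the byte offset into the backing vk_buffer. vk_tensor_offset() recovers that offset,
|  | // following view_src for views.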
| 1621 | static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT |
| 1622 | |
| 1623 | static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { |
| 1624 | if (tensor->view_src) { |
| 1625 | return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; |
| 1626 | } |
| 1627 | return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; |
| 1628 | } |
| 1629 | |
| 1630 | struct ggml_backend_vk_buffer_context { |
| 1631 | vk_device_ref device; |
| 1632 | vk_buffer dev_buffer; |
| 1633 | std::string name; |
| 1634 | |
| 1635 | ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : |
| 1636 | device(device), |
| 1637 | dev_buffer(dev_buffer), |
| 1638 | name(name) { |
| 1639 | } |
| 1640 | |
| 1641 | ~ggml_backend_vk_buffer_context() { |
| 1642 | ggml_vk_destroy_buffer(dev_buffer);
| 1643 | } |
| 1644 | }; |
| 1645 | |
| 1646 | #ifdef GGML_VULKAN_MEMORY_DEBUG |
| 1647 | static std::mutex log_mutex; |
| 1648 | |
| 1649 | void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { |
| 1650 | std::lock_guard<std::mutex> guard(log_mutex); |
| 1651 | vk_buffer buf = buf_ref.lock(); |
| 1652 | const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); |
| 1653 | const std::string type = device ? "device" : "host";
| 1654 | allocations[buf->buffer] = size; |
| 1655 | total_device += device ? size : 0; |
| 1656 | total_host += device ? 0 : size; |
| 1657 | VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); |
| 1658 | } |
| 1659 | |
| 1660 | void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { |
| 1661 | if (buf_ref.expired() || buf_ref.lock()->size == 0) { |
| 1662 | return; |
| 1663 | } |
| 1664 | |
| 1665 | std::lock_guard<std::mutex> guard(log_mutex); |
| 1666 | vk_buffer buf = buf_ref.lock(); |
| 1667 | const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); |
| 1668 | std::string type = device ? "device" : "host";
| 1669 | auto it = allocations.find(buf->buffer);
| 1670 | if (it != allocations.end()) {
| 1671 | total_device -= device ? it->second : 0;
| 1672 | total_host -= device ? 0 : it->second;
| 1673 | VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
| 1674 | allocations.erase(it);
| 1675 | } else { |
| 1676 | VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); |
| 1677 | } |
| 1678 | } |
| 1679 | #endif // GGML_VULKAN_MEMORY_DEBUG |
| 1680 | |
| 1681 | struct vk_instance_t { |
| 1682 | vk::Instance instance; |
| 1683 | |
| 1684 | bool debug_utils_support = false; // VK_EXT_debug_utils enabled |
| 1685 | PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {}; |
| 1686 | PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {}; |
| 1687 | PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {}; |
| 1688 | PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {}; |
| 1689 | PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {}; |
| 1690 | PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {}; |
| 1691 | |
| 1692 | std::vector<size_t> device_indices; |
| 1693 | std::vector<bool> device_supports_membudget; |
| 1694 | vk_device devices[GGML_VK_MAX_DEVICES]; |
| 1695 | }; |
| 1696 | |
| 1697 | static bool vk_instance_initialized = false; |
| 1698 | static vk_instance_t vk_instance; |
| 1699 | |
| 1700 | static bool vk_perf_logger_enabled = false; |
| 1701 | |
| 1702 | #ifdef GGML_VULKAN_CHECK_RESULTS |
| 1703 | static size_t vk_skip_checks; |
| 1704 | static size_t vk_output_tensor; |
| 1705 | |
| 1706 | static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); |
| 1707 | static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); |
| 1708 | static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); |
| 1709 | #endif |
| 1710 | |
| 1711 | typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); |
| 1712 | |
| 1713 | static void ggml_backend_vk_free(ggml_backend_t backend); |
| 1714 | |
| 1715 | static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) { |
| 1716 | const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
| 1717 | VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
| 1718 | return range; |
| 1719 | } |
| 1720 | |
| 1721 | // Wait for ctx->fence to be signaled. |
| 1722 | static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { |
| 1723 | // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep |
| 1724 | // during this wait. |
| 1725 | if (ctx->almost_ready_fence_pending) { |
| 1726 | VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
| 1727 | ctx->device->device.resetFences({ ctx->almost_ready_fence });
| 1728 | ctx->almost_ready_fence_pending = false; |
| 1729 | } |
| 1730 | |
| 1731 | // Spin (w/pause) waiting for the graph to finish executing. |
| 1732 | vk::Result result; |
| 1733 | while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
| 1734 | if (result != vk::Result::eNotReady) {
| 1735 | fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
| 1736 | exit(1);
| 1737 | } |
| 1738 | for (uint32_t i = 0; i < 100; ++i) { |
| 1739 | YIELD(); |
| 1740 | YIELD(); |
| 1741 | YIELD(); |
| 1742 | YIELD(); |
| 1743 | YIELD(); |
| 1744 | YIELD(); |
| 1745 | YIELD(); |
| 1746 | YIELD(); |
| 1747 | YIELD(); |
| 1748 | YIELD(); |
| 1749 | } |
| 1750 | } |
| 1751 | ctx->device->device.resetFences({ ctx->fence });
| 1752 | } |
| 1753 | |
| 1754 | // variables to track number of compiles in progress |
| 1755 | static uint32_t compile_count = 0; |
| 1756 | static std::mutex compile_count_mutex; |
| 1757 | static std::condition_variable compile_count_cond; |
| 1758 | |
| 1759 | static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, |
| 1760 | uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, |
| 1761 | bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { |
| 1762 | VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count << |
| 1763 | ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << |
| 1764 | disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")" ); |
| 1765 | GGML_ASSERT(parameter_count > 0); |
| 1766 | GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT); |
| 1767 | GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT |
| 1768 | |
| 1769 | vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data)); |
| 1770 | pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
| 1771 | |
| 1772 | vk::PushConstantRange pcr( |
| 1773 | vk::ShaderStageFlagBits::eCompute, |
| 1774 | 0, |
| 1775 | pipeline->push_constant_size |
| 1776 | ); |
| 1777 | |
| 1778 | vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr); |
| 1779 | pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
| 1780 | |
| 1781 | std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size()); |
| 1782 | |
| 1783 | for (size_t i = 0; i < specialization_constants.size(); i++) { |
| 1784 | specialization_entries[i].constantID = i; |
| 1785 | specialization_entries[i].offset = i * sizeof(uint32_t); |
| 1786 | specialization_entries[i].size = sizeof(uint32_t); |
| 1787 | } |
| 1788 | |
| 1789 | vk::SpecializationInfo specialization_info( |
| 1790 | specialization_entries.size(), |
| 1791 | specialization_entries.data(), |
| 1792 | specialization_constants.size() * sizeof(uint32_t), |
| 1793 | specialization_constants.data() |
| 1794 | ); |
| 1795 | |
| 1796 | vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; |
| 1797 | |
| 1798 | if (device->subgroup_require_full_support && require_full_subgroups) { |
| 1799 | pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; |
| 1800 | } |
| 1801 | |
| 1802 | vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( |
| 1803 | pipeline_shader_stage_create_flags, |
| 1804 | vk::ShaderStageFlagBits::eCompute, |
| 1805 | pipeline->shader_module, |
| 1806 | entrypoint.c_str(), |
| 1807 | &specialization_info); |
| 1808 | |
| 1809 | vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; |
| 1810 | pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; |
| 1811 | if (device->subgroup_size_control && required_subgroup_size > 0) { |
| 1812 | GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); |
| 1813 | pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); |
| 1814 | } |
| 1815 | |
| 1816 | vk::ComputePipelineCreateInfo compute_pipeline_create_info( |
| 1817 | device->pipeline_executable_properties_support ? |
| 1818 | vk::PipelineCreateFlagBits::eCaptureStatisticsKHR : |
| 1819 | vk::PipelineCreateFlags{}, |
| 1820 | pipeline_shader_create_info, |
| 1821 | pipeline->layout); |
| 1822 | |
| 1823 | vk::PipelineRobustnessCreateInfoEXT rci; |
| 1824 | |
| 1825 | if (device->pipeline_robustness && disable_robustness) { |
| 1826 | rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; |
| 1827 | rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; |
| 1828 | compute_pipeline_create_info.setPNext(&rci); |
| 1829 | } |
| 1830 | |
| 1831 | try { |
| 1832 | pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
| 1833 | } catch (const vk::SystemError& e) { |
| 1834 | std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl; |
| 1835 | std::cerr << "ggml_vulkan: " << e.what() << std::endl; |
| 1836 | throw e; |
| 1837 | } |
| 1838 | pipeline->compiled = true; |
| 1839 | |
| 1840 | if (vk_instance.debug_utils_support) { |
| 1841 | vk::DebugUtilsObjectNameInfoEXT duoni; |
| 1842 | duoni.objectType = vk::ObjectType::ePipeline; |
| 1843 | duoni.pObjectName = pipeline->name.c_str(); |
| 1844 | duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast<VkPipeline>(pipeline->pipeline)); |
| 1845 | vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni)); |
| 1846 | } |
| 1847 | |
| 1848 | if (device->pipeline_executable_properties_support) { |
| 1849 | vk::PipelineExecutableInfoKHR executableInfo; |
| 1850 | executableInfo.pipeline = pipeline->pipeline; |
| 1851 | |
| 1852 | auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); |
| 1853 | for (auto & s : statistics) { |
| 1854 | // "Register Count" is reported by NVIDIA drivers. |
| 1855 | if (strcmp(s.name, "Register Count") == 0) {
| 1856 | VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers" ); |
| 1857 | pipeline->register_count = (uint32_t)s.value.u64; |
| 1858 | } |
| 1859 | } |
| 1860 | } |
| 1861 | |
| 1862 | device->all_pipelines.push_back(pipeline);
| 1863 | |
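|  | // Notify the shader-loading code that this compile finished; compile_count and compile_count_cond
|  | // are used to bound the number of pipeline compiles in flight.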
| 1864 | { |
| 1865 | std::lock_guard<std::mutex> guard(compile_count_mutex); |
| 1866 | assert(compile_count > 0); |
| 1867 | compile_count--; |
| 1868 | } |
| 1869 | compile_count_cond.notify_all(); |
| 1870 | } |
| 1871 | |
| 1872 | static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { |
| 1873 | VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")" ); |
| 1874 | device.destroyPipelineLayout(pipelineLayout: pipeline->layout); |
| 1875 | |
| 1876 | device.destroyShaderModule(shaderModule: pipeline->shader_module); |
| 1877 | |
| 1878 | device.destroyPipeline(pipeline: pipeline->pipeline); |
| 1879 | } |
| 1880 | |
| 1881 | static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) { |
| 1882 | VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")" ); |
| 1883 | ctx->pipeline_descriptor_set_requirements += n; |
| 1884 | if (!pipeline->compiled) { |
| 1885 | pipeline->needed = true; |
| 1886 | ggml_vk_load_shaders(ctx->device);
| 1887 | } |
| 1888 | ggml_pipeline_allocate_descriptor_sets(ctx); |
| 1889 | } |
| 1890 | |
| 1891 | static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) { |
| 1892 | |
| 1893 | if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) { |
| 1894 | // Enough descriptors are available |
| 1895 | return; |
| 1896 | } |
| 1897 | |
| 1898 | vk_device& device = ctx->device; |
| 1899 | |
| 1900 | // Grow by 50% to avoid frequent allocations |
| 1901 | uint32_t needed = std::max(3 * ctx->descriptor_sets.size() / 2, size_t{ctx->pipeline_descriptor_set_requirements});
| 1902 | uint32_t to_alloc = needed - ctx->descriptor_sets.size(); |
| 1903 | uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; |
| 1904 | uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; |
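|  | // Example: with VK_DEVICE_DESCRIPTOR_POOL_SIZE == 256 and 300 sets already allocated,
|  | // pool_remaining == 212 (space left in pool 1) and pool_idx == 1.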
| 1905 | |
| 1906 | while (to_alloc > 0) { |
| 1907 | const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
| 1908 | to_alloc -= alloc_count; |
| 1909 | pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; |
| 1910 | |
| 1911 | if (pool_idx >= ctx->descriptor_pools.size()) { |
| 1912 | vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE); |
| 1913 | vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); |
| 1914 | ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
| 1915 | } |
| 1916 | |
| 1917 | std::vector<vk::DescriptorSetLayout> layouts(alloc_count); |
| 1918 | for (uint32_t i = 0; i < alloc_count; i++) { |
| 1919 | layouts[i] = device->dsl; |
| 1920 | } |
| 1921 | vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data()); |
| 1922 | std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
| 1923 | ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
| 1924 | |
| 1925 | pool_idx++; |
| 1926 | } |
| 1927 | } |
| 1928 | |
| 1929 | static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { |
| 1930 | VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()" ); |
| 1931 | |
| 1932 | if (p.cmd_buffers.size() > p.cmd_buffer_idx) { |
| 1933 | // Reuse command buffer |
| 1934 | return p.cmd_buffers[p.cmd_buffer_idx++]; |
| 1935 | } |
| 1936 | |
| 1937 | vk::CommandBufferAllocateInfo command_buffer_alloc_info( |
| 1938 | p.pool, |
| 1939 | vk::CommandBufferLevel::ePrimary, |
| 1940 | 1); |
| 1941 | const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
| 1942 | auto buf = cmd_buffers.front(); |
| 1943 | |
| 1944 | p.cmd_buffers.push_back(buf);
| 1945 | p.cmd_buffer_idx++; |
| 1946 | |
| 1947 | return buf; |
| 1948 | } |
| 1949 | |
| 1950 | static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { |
| 1951 | if (ctx->seqs.empty()) { |
| 1952 | if (fence) { |
| 1953 | std::lock_guard<std::mutex> guard(queue_mutex); |
| 1954 | ctx->p->q->queue.submit({}, fence);
| 1955 | } |
| 1956 | return; |
| 1957 | } |
| 1958 | VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")" ); |
| 1959 | |
| 1960 | std::vector<std::vector<uint64_t>> tl_wait_vals; |
| 1961 | std::vector<std::vector<uint64_t>> tl_signal_vals; |
| 1962 | std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores; |
| 1963 | std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores; |
| 1964 | std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos; |
| 1965 | std::vector<vk::SubmitInfo> submit_infos; |
| 1966 | int idx = -1; |
| 1967 | std::vector<std::vector<vk::PipelineStageFlags>> stage_flags; |
| 1968 | |
| 1969 | size_t reserve = 0; |
| 1970 | |
| 1971 | for (const auto& sequence : ctx->seqs) { |
| 1972 | reserve += sequence.size(); |
| 1973 | } |
| 1974 | |
| 1975 | // Pre-reserve vectors to prevent reallocation, which invalidates pointers |
| 1976 | tl_wait_semaphores.reserve(reserve);
| 1977 | tl_wait_vals.reserve(reserve);
| 1978 | tl_signal_semaphores.reserve(reserve);
| 1979 | tl_signal_vals.reserve(reserve);
| 1980 | tl_submit_infos.reserve(reserve);
| 1981 | submit_infos.reserve(reserve);
| 1982 | stage_flags.reserve(reserve);
| 1983 | |
| 1984 | for (const auto& sequence : ctx->seqs) { |
| 1985 | for (const auto& submission : sequence) { |
| 1986 | stage_flags.push_back({});
| 1987 | idx++;
| 1988 | tl_wait_vals.push_back({});
| 1989 | tl_wait_semaphores.push_back({});
| 1990 | tl_signal_vals.push_back({});
| 1991 | tl_signal_semaphores.push_back({});
| 1992 | for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
| 1993 | stage_flags[idx].push_back(ctx->p->q->stage_flags);
| 1994 | tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
| 1995 | tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
| 1996 | }
| 1997 | for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
| 1998 | tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
| 1999 | tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
| 2000 | }
| 2001 | tl_submit_infos.push_back({
| 2002 | (uint32_t) submission.wait_semaphores.size(), |
| 2003 | tl_wait_vals[idx].data(), |
| 2004 | (uint32_t) submission.signal_semaphores.size(), |
| 2005 | tl_signal_vals[idx].data(), |
| 2006 | }); |
| 2007 | tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo; |
| 2008 | tl_submit_infos[idx].pNext = nullptr; |
| 2009 | vk::SubmitInfo si{ |
| 2010 | (uint32_t) submission.wait_semaphores.size(), |
| 2011 | tl_wait_semaphores[idx].data(), |
| 2012 | stage_flags[idx].data(), |
| 2013 | 1, |
| 2014 | &submission.buffer, |
| 2015 | (uint32_t) submission.signal_semaphores.size(), |
| 2016 | tl_signal_semaphores[idx].data(), |
| 2017 | }; |
| 2018 | si.setPNext(&tl_submit_infos[idx]); |
| 2019 | submit_infos.push_back(si);
| 2020 | } |
| 2021 | } |
| 2022 | |
| 2023 | std::lock_guard<std::mutex> guard(queue_mutex); |
| 2024 | ctx->p->q->queue.submit(submit_infos, fence);
| 2025 | |
| 2026 | ctx->seqs.clear(); |
| 2027 | } |
| 2028 | |
| 2029 | static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) { |
| 2030 | VK_LOG_DEBUG("ggml_vk_find_queue_family_index()" ); |
| 2031 | const uint32_t qfsize = queue_family_props.size(); |
| 2032 | |
| 2033 | // Try with avoid preferences first |
| 2034 | for (uint32_t i = 0; i < qfsize; i++) { |
| 2035 | if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { |
| 2036 | return i; |
| 2037 | } |
| 2038 | } |
| 2039 | |
| 2040 | // Fall back to only required |
| 2041 | for (size_t i = 0; i < qfsize; i++) { |
| 2042 | if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) { |
| 2043 | return i; |
| 2044 | } |
| 2045 | } |
| 2046 | |
| 2047 | // Fall back to reusing compute queue |
| 2048 | for (size_t i = 0; i < qfsize; i++) { |
| 2049 | if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { |
| 2050 | return i; |
| 2051 | } |
| 2052 | } |
| 2053 | |
| 2054 | // Fall back to ignoring min_num_queues
| 2055 | for (size_t i = 0; i < qfsize; i++) { |
| 2056 | if (queue_family_props[i].queueFlags & required) { |
| 2057 | return i; |
| 2058 | } |
| 2059 | } |
| 2060 | |
| 2061 | // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations. |
| 2062 | // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional. |
| 2063 | if (compute_index >= 0) { |
| 2064 | return compute_index; |
| 2065 | } |
| 2066 | |
| 2067 | std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl; |
| 2068 | |
| 2069 | for(auto &q_family : queue_family_props) { |
| 2070 | std::cerr << "Queue number: " + std::to_string(val: q_family.queueCount) << " flags: " + to_string(value: q_family.queueFlags) << std::endl; |
| 2071 | } |
| 2072 | abort(); |
| 2073 | } |
| 2074 | |
| 2075 | static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { |
| 2076 | VK_LOG_DEBUG("ggml_vk_create_queue()" ); |
| 2077 | std::lock_guard<std::recursive_mutex> guard(device->mutex); |
| 2078 | |
| 2079 | q.queue_family_index = queue_family_index; |
| 2080 | q.transfer_only = transfer_only; |
| 2081 | |
| 2082 | q.cmd_pool.init(device, &q);
| 2083 | 
| 2084 | q.queue = device->device.getQueue(queue_family_index, queue_index);
| 2085 | |
| 2086 | q.stage_flags = stage_flags; |
| 2087 | } |
| 2088 | |
| 2089 | static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) { |
| 2090 | vk_context result = std::make_shared<vk_context_struct>(); |
| 2091 | VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")" ); |
| 2092 | ctx->gc.contexts.emplace_back(result);
| 2093 | result->p = &p; |
| 2094 | return result; |
| 2095 | } |
| 2096 | |
| 2097 | static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) { |
| 2098 | vk_context result = std::make_shared<vk_context_struct>(); |
| 2099 | VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")" ); |
| 2100 | result->p = &p; |
| 2101 | return result; |
| 2102 | } |
| 2103 | |
| 2104 | static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { |
| 2105 | VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()" ); |
| 2106 | vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; |
| 2107 | vk::SemaphoreCreateInfo ci{}; |
| 2108 | ci.setPNext(&tci); |
| 2109 | vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
| 2110 | ctx->gc.semaphores.push_back({ semaphore, 0 });
| 2111 | return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; |
| 2112 | } |
| 2113 | |
| 2114 | static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { |
| 2115 | VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()" ); |
| 2116 | if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { |
| 2117 | vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; |
| 2118 | vk::SemaphoreCreateInfo ci{}; |
| 2119 | ci.setPNext(&tci); |
| 2120 | vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
| 2121 | ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
| 2122 | } |
| 2123 | return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; |
| 2124 | } |
| 2125 | |
| 2126 | static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { |
| 2127 | if (ctx->event_idx >= ctx->gc.events.size()) { |
| 2128 | ctx->gc.events.push_back(ctx->device->device.createEvent({}));
| 2129 | } |
| 2130 | return ctx->gc.events[ctx->event_idx++]; |
| 2131 | } |
| 2132 | |
| 2133 | static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) { |
| 2134 | VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()" ); |
| 2135 | |
| 2136 | // Requires command buffers to be done |
| 2137 | device->device.resetCommandPool(p.pool);
| 2138 | p.cmd_buffer_idx = 0; |
| 2139 | } |
| 2140 | |
| 2141 | static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { |
| 2142 | VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()" ); |
| 2143 | |
| 2144 | // Arbitrary frequency to cleanup/reuse command buffers |
| 2145 | static constexpr uint32_t cleanup_frequency = 10; |
| 2146 | |
| 2147 | if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { |
| 2148 | ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
| 2149 | }
| 2150 | if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
| 2151 | ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
| 2152 | } |
| 2153 | } |
| 2154 | |
| 2155 | |
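|  | // Return the index of the first memory type that is allowed by mem_req->memoryTypeBits, has all
|  | // of the requested property flags, and whose heap is large enough; UINT32_MAX if none matches.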
| 2156 | static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { |
| 2157 | for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { |
| 2158 | vk::MemoryType memory_type = mem_props->memoryTypes[i]; |
| 2159 | if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && |
| 2160 | (flags & memory_type.propertyFlags) == flags && |
| 2161 | mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { |
| 2162 | return static_cast<uint32_t>(i);
| 2163 | } |
| 2164 | } |
| 2165 | return UINT32_MAX; |
| 2166 | } |
| 2167 | |
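|  | // req_flags_list is tried in order: the first entry for which a suitable memory type exists and
|  | // the allocation succeeds wins; if every entry fails, the exception from the last attempt is rethrown.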
| 2168 | static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) { |
| 2169 | VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")" ); |
| 2170 | if (size > device->max_buffer_size) { |
| 2171 | throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit" ); |
| 2172 | } |
| 2173 | |
| 2174 | vk_buffer buf = std::make_shared<vk_buffer_struct>(); |
| 2175 | |
| 2176 | if (size == 0) { |
| 2177 | buf->size = 0; |
| 2178 | return buf; |
| 2179 | } |
| 2180 | |
| 2181 | vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst; |
| 2182 | vk::MemoryAllocateFlags mem_flags {}; |
| 2183 | if (device->buffer_device_address) { |
| 2184 | usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress; |
| 2185 | mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress; |
| 2186 | } |
| 2187 | |
| 2188 | vk::BufferCreateInfo buffer_create_info{ |
| 2189 | vk::BufferCreateFlags(), |
| 2190 | size, |
| 2191 | usage_flags, |
| 2192 | vk::SharingMode::eExclusive, |
| 2193 | 0, |
| 2194 | nullptr, |
| 2195 | }; |
| 2196 | |
| 2197 | buf->buffer = device->device.createBuffer(buffer_create_info);
| 2198 | 
| 2199 | vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
| 2200 | |
| 2201 | vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); |
| 2202 | |
| 2203 | const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags }; |
| 2204 | |
| 2205 | for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { |
| 2206 | const auto & req_flags = *it; |
| 2207 | |
| 2208 | uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
| 2209 | |
| 2210 | if (memory_type_index == UINT32_MAX) { |
| 2211 | continue; |
| 2212 | } |
| 2213 | buf->memory_property_flags = req_flags; |
| 2214 | |
| 2215 | try { |
| 2216 | buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
| 2217 | break; |
| 2218 | } catch (const vk::SystemError& e) { |
| 2219 | // loop and retry |
| 2220 | // during last attempt throw the exception |
| 2221 | if (it + 1 == req_flags_list.end()) { |
| 2222 | device->device.destroyBuffer(buf->buffer);
| 2223 | throw e; |
| 2224 | } |
| 2225 | } |
| 2226 | } |
| 2227 | |
| 2228 | if (!buf->device_memory) { |
| 2229 | device->device.destroyBuffer(buf->buffer);
| 2230 | throw vk::OutOfDeviceMemoryError("No suitable memory type found");
| 2231 | } |
| 2232 | |
| 2233 | buf->ptr = nullptr; |
| 2234 | |
| 2235 | if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { |
| 2236 | buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
| 2237 | } |
| 2238 | |
| 2239 | device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
| 2240 | |
| 2241 | buf->device = device; |
| 2242 | buf->size = size; |
| 2243 | |
| 2244 | if (device->buffer_device_address) { |
| 2245 | const vk::BufferDeviceAddressInfo addressInfo(buf->buffer); |
| 2246 | buf->bda_addr = device->device.getBufferAddress(addressInfo);
| 2247 | } |
| 2248 | |
| 2249 | #ifdef GGML_VULKAN_MEMORY_DEBUG |
| 2250 | device->memory_logger->log_allocation(buf, size); |
| 2251 | #endif |
| 2252 | |
| 2253 | return buf; |
| 2254 | } |
| 2255 | |
| 2256 | static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { |
| 2257 | try { |
| 2258 | return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
| 2259 | } catch (const vk::SystemError& e) { |
| 2260 | std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; |
| 2261 | std::cerr << "ggml_vulkan: " << e.what() << std::endl; |
| 2262 | throw e; |
| 2263 | } |
| 2264 | } |
| 2265 | |
| 2266 | static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { |
| 2267 | vk_buffer buf; |
| 2268 | try { |
| 2269 | if (device->prefer_host_memory) { |
| 2270 | buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
| 2271 | vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 2272 | } else if (device->uma) { |
| 2273 | // Fall back to host memory type |
| 2274 | buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
| 2275 | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); |
| 2276 | } else if (device->disable_host_visible_vidmem) { |
| 2277 | if (device->allow_sysmem_fallback) { |
| 2278 | buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
| 2279 | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); |
| 2280 | } else { |
| 2281 | buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
| 2282 | } |
| 2283 | } else { |
| 2284 | // use rebar if available, otherwise fallback to device only visible memory |
| 2285 | if (device->allow_sysmem_fallback) { |
| 2286 | buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
| 2287 | vk::MemoryPropertyFlagBits::eDeviceLocal, |
| 2288 | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); |
| 2289 | } else { |
| 2290 | buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
| 2291 | vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 2292 | } |
| 2293 | } |
| 2294 | } catch (const vk::SystemError& e) { |
| 2295 | std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; |
| 2296 | std::cerr << "ggml_vulkan: " << e.what() << std::endl; |
| 2297 | throw e; |
| 2298 | } |
| 2299 | |
| 2300 | return buf; |
| 2301 | } |
| 2302 | |
| 2303 | static void ggml_vk_destroy_buffer(vk_buffer& buf) { |
| 2304 | if (buf == nullptr) { |
| 2305 | return; |
| 2306 | } |
| 2307 | |
| 2308 | #ifdef GGML_VULKAN_MEMORY_DEBUG |
| 2309 | if (buf->device != nullptr) { |
| 2310 | buf->device->memory_logger->log_deallocation(buf); |
| 2311 | } |
| 2312 | #endif |
| 2313 | |
| 2314 | buf.reset(); |
| 2315 | } |
| 2316 | |
| 2317 | static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) { |
| 2318 | return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
| 2319 | } |
| 2320 | |
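|  | // Record a full execution + memory barrier on the context's queue. On a transfer-only queue only
|  | // transfer read/write access is synchronized, since shader access masks do not apply there.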
| 2321 | static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { |
| 2322 | VK_LOG_DEBUG("ggml_vk_sync_buffers()" ); |
| 2323 | |
| 2324 | const bool transfer_queue = subctx->p->q->transfer_only; |
| 2325 | |
| 2326 | if (ctx) { |
| 2327 | ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; |
| 2328 | } |
| 2329 | |
| 2330 | subctx->s->buffer.pipelineBarrier( |
| 2331 | subctx->p->q->stage_flags,
| 2332 | subctx->p->q->stage_flags,
| 2333 | {},
| 2334 | { {
| 2335 | { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
| 2336 | { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
| 2337 | } },
| 2338 | {},
| 2339 | {}
| 2340 | ); |
| 2341 | } |
| 2342 | |
| 2343 | static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) { |
| 2344 | VK_LOG_DEBUG("ggml_vk_wait_events()" ); |
| 2345 | if (events.empty()) { |
| 2346 | return; |
| 2347 | } |
| 2348 | |
| 2349 | ctx->s->buffer.waitEvents( |
| 2350 | events, |
| 2351 | ctx->p->q->stage_flags,
| 2352 | ctx->p->q->stage_flags,
| 2353 | {},
| 2354 | {},
| 2355 | {}
| 2356 | ); |
| 2357 | } |
| 2358 | |
| 2359 | // number of rows/cols for flash attention shader |
| 2360 | static constexpr uint32_t flash_attention_num_small_rows = 32; |
| 2361 | static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; |
| 2362 | |
| 2363 | static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { |
| 2364 | if (hsv >= 192) { |
| 2365 | return 2; |
| 2366 | } else { |
| 2367 | return 8; |
| 2368 | } |
| 2369 | } |
| 2370 | |
| 2371 | // The FA coopmat1 shader assumes 16x16x16 matrix multiply support. |
| 2372 | // 128 threads split into four subgroups, each subgroup does 1/4 |
| 2373 | // of the Bc dimension. |
| 2374 | static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; |
| 2375 | static constexpr uint32_t scalar_flash_attention_Bc = 64; |
| 2376 | static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; |
| 2377 | |
| 2378 | static uint32_t get_fa_num_small_rows(FaCodePath path) { |
| 2379 | if (path == FA_COOPMAT2) { |
| 2380 | return flash_attention_num_small_rows; |
| 2381 | } else { |
| 2382 | return scalar_flash_attention_num_small_rows; |
| 2383 | } |
| 2384 | } |
| 2385 | |
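|  | // Returns the {rows, cols} tile shape used by the flash attention shader for the given code path,
|  | // head sizes, and source type; the second element is the Bc (column) block size.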
| 2386 | static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { |
| 2387 | GGML_UNUSED(clamp); |
| 2388 | GGML_UNUSED(hsv); |
| 2389 | |
| 2390 | if (path == FA_SCALAR) { |
| 2391 | if (small_rows) { |
| 2392 | return {scalar_flash_attention_num_small_rows, 64}; |
| 2393 | } else { |
| 2394 | if ((hsv | hsk) & 8) { |
| 2395 | // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter |
| 2396 | // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. |
| 2397 | return {get_fa_scalar_num_large_rows(hsv), 64}; |
| 2398 | } else { |
| 2399 | return {get_fa_scalar_num_large_rows(hsv), 32}; |
| 2400 | } |
| 2401 | } |
| 2402 | } |
| 2403 | |
| 2404 | if (path == FA_COOPMAT1) { |
| 2405 | if (small_rows) { |
| 2406 | return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; |
| 2407 | } else { |
| 2408 | return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; |
| 2409 | } |
| 2410 | } |
| 2411 | |
| 2412 | // small rows, large cols |
| 2413 | if (small_rows) { |
| 2414 | return {get_fa_num_small_rows(FA_COOPMAT2), 32};
| 2415 | } |
| 2416 | |
| 2417 | // small cols to reduce register count |
| 2418 | if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { |
| 2419 | if (hsk >= 512 || hsv >= 512) { |
| 2420 | return {32, 32}; |
| 2421 | } else { |
| 2422 | return {64, 32}; |
| 2423 | } |
| 2424 | } |
| 2425 | return {64, 64}; |
| 2426 | } |
| 2427 | |
| 2428 | static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { |
| 2429 | return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
| 2430 | } |
| 2431 | |
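|  | // Rough shared-memory estimate for a matmul warptile: staging for (warptile[1] + warptile[2]) rows
|  | // of warptile[3] elements each, plus optional mul_mat_id row ids and ballot scratch, a coopmat
|  | // staging area, and a dequant LUT for the i-quant types, checked against maxComputeSharedMemorySize.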
| 2432 | static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) { |
| 2433 | |
| 2434 | uint32_t lut_size = 0; |
| 2435 | switch (src0_type) { |
| 2436 | case GGML_TYPE_IQ1_S: |
| 2437 | case GGML_TYPE_IQ1_M: |
| 2438 | lut_size = 2*2048; |
| 2439 | break; |
| 2440 | case GGML_TYPE_IQ2_XXS: |
| 2441 | lut_size = 8*256; |
| 2442 | break; |
| 2443 | case GGML_TYPE_IQ2_XS: |
| 2444 | lut_size = 8*512; |
| 2445 | break; |
| 2446 | case GGML_TYPE_IQ2_S: |
| 2447 | lut_size = 8*1024; |
| 2448 | break; |
| 2449 | case GGML_TYPE_IQ3_XXS: |
| 2450 | lut_size = 4*256; |
| 2451 | break; |
| 2452 | case GGML_TYPE_IQ3_S: |
| 2453 | lut_size = 4*512; |
| 2454 | break; |
| 2455 | case GGML_TYPE_IQ4_NL: |
| 2456 | case GGML_TYPE_IQ4_XS: |
| 2457 | case GGML_TYPE_MXFP4: |
| 2458 | lut_size = 4*16; |
| 2459 | break; |
| 2460 | default: |
| 2461 | break; |
| 2462 | } |
| 2463 | |
| 2464 | // Needs to be kept up to date on shader changes |
| 2465 | const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; |
| 2466 | const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); |
| 2467 | const uint32_t warps = warptile[0] / warptile[10]; |
| 2468 | |
| 2469 | const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; |
| 2470 | const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0; |
| 2471 | const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; |
| 2472 | const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0; |
| 2473 | |
| 2474 | const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh; |
| 2475 | const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; |
| 2476 | |
| 2477 | VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " |
| 2478 | "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); |
| 2479 | |
| 2480 | return supported; |
| 2481 | } |
| 2482 | |
| 2483 | struct GpuPipelineConfig { |
| 2484 | // GPU architecture identifier. |
| 2485 | // Example: vk_device_architecture::AMD_GCN |
| 2486 | vk_device_architecture arch; |
| 2487 | |
| 2488 | // Mapping of pipeline names to their specific subgroup sizes. |
| 2489 | // Example: {"soft_max_f32", 64} |
| 2490 | std::unordered_map<std::string, uint32_t> pipelines; |
| 2491 | |
| 2492 | // Default subgroup size for this GPU. |
| 2493 | // Defaults to 0 if not explicitly provided. |
| 2494 | uint32_t default_subgroup_size = 0; |
| 2495 | }; |
| 2496 | |
| 2497 | // Pipeline configuration for RDNA1 GPUs. |
| 2498 | static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = { |
| 2499 | {"soft_max" , 64}, {"im2col" , 64}, |
| 2500 | {"argmax" , 64}, {"mul_mat_vec" , 64}, |
| 2501 | {"mul_mat_vec_f16" , 32}, {"mul_mat_vec_f32_f16" , 32} |
| 2502 | }; |
| 2503 | |
| 2504 | // Pipeline configuration for RDNA2 GPUs. |
| 2505 | static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = { |
| 2506 | {"soft_max" , 64}, {"im2col" , 64}, |
| 2507 | }; |
| 2508 | |
| 2509 | static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32; |
| 2510 | |
| 2511 | // Define configurations for different GPUs. |
| 2512 | static std::vector<GpuPipelineConfig> gpu_pipeline_configs = { |
| 2513 | { |
| 2514 | vk_device_architecture::AMD_RDNA1,
| 2515 | {
| 2516 | rdna1_pipelines,
| 2517 | },
| 2518 | RDNA_DEFAULT_SUBGROUP_SIZE
| 2519 | },
| 2520 | {
| 2521 | vk_device_architecture::AMD_RDNA2,
| 2522 | {
| 2523 | rdna2_pipelines,
| 2524 | },
| 2525 | RDNA_DEFAULT_SUBGROUP_SIZE
| 2526 | }, |
| 2527 | }; |
| 2528 | |
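|  | // Resolve the required subgroup size for a pipeline on a given architecture: exact name match first,
|  | // then the longest substring match, then the architecture default (0 means no override). For example,
|  | // on RDNA1 a hypothetical pipeline name containing "mul_mat_vec_f16" picks 32 from that entry rather
|  | // than 64 from the shorter "mul_mat_vec" key, because longer keys are tried first.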
| 2529 | static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) { |
| 2530 | for (const auto &config : gpu_pipeline_configs) { |
| 2531 | if (config.arch == arch) { |
| 2532 | auto pipIt = config.pipelines.find(pipeline_name);
| 2533 | if (pipIt != config.pipelines.end()) { |
| 2534 | return pipIt->second; |
| 2535 | } |
| 2536 | std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end()); |
| 2537 | std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
| 2538 | [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
| 2539 | for (const auto &entry : sorted_pipelines) {
| 2540 | if (pipeline_name.find(entry.first) != std::string::npos) {
| 2541 | return entry.second; |
| 2542 | } |
| 2543 | } |
| 2544 | return config.default_subgroup_size; |
| 2545 | } |
| 2546 | } |
| 2547 | return 0; // If no matching configuration is found |
| 2548 | } |
| 2549 | |
| 2550 | static void ggml_vk_load_shaders(vk_device& device) { |
| 2551 | VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")" ); |
| 2552 | |
| 2553 | std::lock_guard<std::recursive_mutex> guard(device->mutex); |
| 2554 | // some shaders have a minimum subgroup size |
| 2555 | const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
| 2556 | const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
| 2557 | const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
| 2558 | |
| 2559 | const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; |
| 2560 | const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
| 2561 | const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
| 2562 | const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
| 2563 | |
| 2564 | const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) || |
| 2565 | (device->subgroup_size_control && device->subgroup_max_size >= 16); |
| 2566 | |
| 2567 | // mulmat |
| 2568 | std::vector<uint32_t> l_warptile, m_warptile, s_warptile, |
| 2569 | l_warptile_id, m_warptile_id, s_warptile_id, |
| 2570 | l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, |
| 2571 | l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, |
| 2572 | l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k, |
| 2573 | l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, |
| 2574 | l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid, |
| 2575 | l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int, |
| 2576 | l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k; |
| 2577 | std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms, |
| 2578 | l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, |
| 2579 | l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, |
| 2580 | l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; |
| 2581 | |
| 2582 | uint32_t l_align, m_align, s_align; |
| 2583 | if (device->coopmat2) { |
| 2584 | // spec constants and tile sizes for non-quant matmul/matmul_id |
| 2585 | l_warptile = { 256, 128, 256, 64, 1 }; |
| 2586 | m_warptile = { 256, 128, 128, 64, 0 }; |
| 2587 | s_warptile = { 128, 64, 64, 64, 0 }; |
| 2588 | l_wg_denoms = {128, 256, 1 }; |
| 2589 | m_wg_denoms = {128, 128, 1 }; |
| 2590 | s_wg_denoms = { 64, 64, 1 }; |
| 2591 | |
| 2592 | // spec constants and tile sizes for quant matmul (non-Qi_K) |
| 2593 | l_warptile_mmq = { 256, 128, 256, 64, 1 }; |
| 2594 | m_warptile_mmq = { 256, 128, 128, 64, 1 }; |
| 2595 | s_warptile_mmq = { 256, 32, 64, 128, 0 }; |
| 2596 | l_mmq_wg_denoms = { 128, 256, 1 }; |
| 2597 | m_mmq_wg_denoms = { 128, 128, 1 }; |
| 2598 | s_mmq_wg_denoms = { 32, 64, 1 }; |
| 2599 | |
| 2600 | // spec constants and tile sizes for quant matmul (Qi_K) |
| 2601 | l_warptile_mmq_k = { 256, 128, 256, 64, 1 }; |
| 2602 | m_warptile_mmq_k = { 256, 128, 128, 64, 1 }; |
| 2603 | s_warptile_mmq_k = { 256, 32, 64, 128, 0 }; |
| 2604 | l_mmq_wg_denoms_k = { 128, 256, 1 }; |
| 2605 | m_mmq_wg_denoms_k = { 128, 128, 1 }; |
| 2606 | s_mmq_wg_denoms_k = { 32, 64, 1 }; |
| 2607 | |
| 2608 | // spec constants and tile sizes for quant matmul_id |
| 2609 | l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size }; |
| 2610 | m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; |
| 2611 | s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; |
| 2612 | l_mmqid_wg_denoms = { 128, 128, 1 }; |
| 2613 | m_mmqid_wg_denoms = { 128, 64, 1 }; |
| 2614 | s_mmqid_wg_denoms = { 128, 64, 1 }; |
| 2615 | |
| 2616 | l_align = 128; |
| 2617 | m_align = 64; |
| 2618 | s_align = 32; |
| 2619 | } else { |
| 2620 | // Matrix cores require different warp group sizes |
| 2621 | const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; |
| 2622 | const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; |
| 2623 | const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; |
| 2624 | const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; |
| 2625 | const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; |
| 2626 | const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; |
| 2627 | const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; |
| 2628 | const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; |
| 2629 | const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; |
| 2630 | |
| 2631 | l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; |
| 2632 | m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; |
| 2633 | s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; |
| 2634 | |
| 2635 | l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; |
| 2636 | m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; |
| 2637 | s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; |
| 2638 | |
| 2639 | // Integer MMQ has a smaller shared memory profile, but heavier register use |
| 2640 | l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; |
| 2641 | m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; |
| 2642 | s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; |
| 2643 | |
| 2644 | // K-quants use even more registers, mitigate by setting WMITER to 1 |
| 2645 | l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; |
| 2646 | m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; |
| 2647 | s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 }; |
| 2648 | |
| 2649 | l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; |
| 2650 | m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; |
| 2651 | s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; |
| 2652 | |
| 2653 | l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; |
| 2654 | m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; |
| 2655 | s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; |
| 2656 | |
| 2657 | l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; |
| 2658 | m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; |
| 2659 | s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; |
| 2660 | |
| 2661 | l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; |
| 2662 | m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; |
| 2663 | s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; |
| 2664 | |
| 2665 | // chip specific tuning |
| 2666 | if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { |
| 2667 | m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; |
| 2668 | m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; |
| 2669 | } |
| 2670 | |
| 2671 | l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; |
| 2672 | m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; |
| 2673 | s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; |
| 2674 | l_align = 128; |
| 2675 | m_align = 64; |
| 2676 | s_align = 32; |
| 2677 | |
| 2678 | for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { |
| 2679 | ggml_type t = (ggml_type)i; |
| 2680 | // Disable medium and large matrix multiplication if not enough shared memory is available |
| 2681 | // Check mmq warptiles as the largest configuration |
| 2682 | // Throw an error if there is not enough shared memory for any matrix multiplication
| 2683 | if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
| 2684 | std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
| 2685 | throw std::runtime_error("Shared memory size too small for matrix multiplication.");
| 2686 | } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
| 2687 | device->mul_mat_m[i] = false;
| 2688 | device->mul_mat_l[i] = false;
| 2689 | } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
| 2690 | device->mul_mat_l[i] = false; |
| 2691 | } |
| 2692 | |
| 2693 | // Disable mul_mat_id if not enough shared memory is available |
| 2694 | if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
| 2695 | device->mul_mat_id_s[i] = false;
| 2696 | device->mul_mat_id_m[i] = false;
| 2697 | device->mul_mat_id_l[i] = false;
| 2698 | } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
| 2699 | device->mul_mat_id_m[i] = false;
| 2700 | device->mul_mat_id_l[i] = false;
| 2701 | } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
| 2702 | device->mul_mat_id_l[i] = false; |
| 2703 | } |
| 2704 | } |
| 2705 | } |
| 2706 | |
| 2707 | if (!device->pipeline_matmul_f32) { |
| 2708 | device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>(); |
| 2709 | } |
| 2710 | if (!device->pipeline_matmul_f32_f16) { |
| 2711 | device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>(); |
| 2712 | } |
| 2713 | if (!device->pipeline_matmul_id_f32) { |
| 2714 | device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>(); |
| 2715 | } |
| 2716 | if (!device->pipeline_matmul_bf16) { |
| 2717 | device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>(); |
| 2718 | } |
| 2719 | if (!device->pipeline_matmul_id_bf16) { |
| 2720 | device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>(); |
| 2721 | } |
| 2722 | |
| 2723 | std::vector<std::future<void>> compiles; |
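|      | // Helper: records a pipeline's static metadata on first use and, if the pipeline is needed and not
|      | // yet compiled, queues an asynchronous compile. In-flight compiles are capped at hardware_concurrency().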
| 2724 | auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, |
| 2725 | uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, |
| 2726 | uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { |
| 2727 | |
| 2728 | if (!require_full_subgroups && required_subgroup_size == 0) { |
| 2729 | required_subgroup_size = get_subgroup_size(name, device->architecture);
| 2730 | } |
| 2731 | |
| 2732 | if (!pipeline) { |
| 2733 | pipeline = std::make_shared<vk_pipeline_struct>(); |
| 2734 | } |
| 2735 | if (!pipeline->initialized) { |
| 2736 | pipeline->name = name; |
| 2737 | pipeline->parameter_count = parameter_count; |
| 2738 | pipeline->push_constant_size = push_constant_size; |
| 2739 | pipeline->wg_denoms = wg_denoms; |
| 2740 | pipeline->align = align; |
| 2741 | pipeline->initialized = true; |
| 2742 | } |
| 2743 | |
| 2744 | if (!pipeline->needed || pipeline->compiled) { |
| 2745 | return; |
| 2746 | } |
| 2747 | // TODO: We're no longer benefitting from the async compiles (shaders are |
| 2748 | // compiled individually, as needed) and this complexity can be removed. |
| 2749 | { |
| 2750 | // wait until fewer than N compiles are in progress |
| 2751 | uint32_t N = std::max(1u, std::thread::hardware_concurrency());
| 2752 | std::unique_lock<std::mutex> guard(compile_count_mutex);
| 2753 | while (compile_count >= N) {
| 2754 | compile_count_cond.wait(guard);
| 2755 | } |
| 2756 | compile_count++; |
| 2757 | } |
| 2758 | |
| 2759 | compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
| 2760 | parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
| 2761 | }; |
| 2762 | |
| 2763 | auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint, |
| 2764 | uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, |
| 2765 | uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { |
| 2766 | return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint, |
| 2767 | parameter_count, push_constant_size, wg_denoms, specialization_constants, |
| 2768 | align, disable_robustness, require_full_subgroups, required_subgroup_size); |
| 2769 | }; |
| 2770 | |
| 2771 | auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> { |
| 2772 | return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; |
| 2773 | }; |
| 2774 | |
| 2775 | auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> { |
| 2776 | // For a large number of rows, 128 invocations seems to work best.
| 2777 | // For a small number of rows (e.g. N==1), 256 works better. But the matrix granularity for 256 is 32, so we
| 2778 | // can't use 256 for D==80.
| 2779 | // For scalar, use 128 (arbitrary) |
| 2780 | // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. |
| 2781 | const uint32_t D = (hsk|hsv); |
| 2782 | uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) |
| 2783 | ? scalar_flash_attention_workgroup_size |
| 2784 | : ((small_rows && (D % 32) == 0) ? 256 : 128); |
| 2785 | auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); |
| 2786 | |
| 2787 | // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. |
| 2788 | // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. |
| 2789 | const uint32_t D_lsb = D ^ (D & (D-1)); |
| 2790 | uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
| 2791 | |
| 2792 | return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; |
| 2793 | }; |
| 2794 | |
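|      | // CREATE_FA instantiates every requested flash attention pipeline for a given source type and code
|      | // path: {aligned, unaligned} x {f32, f16} accumulator.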
| 2795 | #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ |
| 2796 | for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ |
| 2797 | uint32_t HSK = fa.first.HSK; \ |
| 2798 | uint32_t HSV = fa.first.HSV; \ |
| 2799 | bool small_rows = fa.first.small_rows; \ |
| 2800 | FaCodePath path = fa.first.path; \ |
| 2801 | bool aligned = fa.first.aligned; \ |
| 2802 | bool f32acc = fa.first.f32acc; \ |
| 2803 | if (path == FAPATH) { \ |
| 2804 | if (aligned) { \ |
| 2805 | if (f32acc) { \ |
| 2806 | ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ |
| 2807 | } else { \ |
| 2808 | ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ |
| 2809 | } \ |
| 2810 | } else { \ |
| 2811 | if (f32acc) { \ |
| 2812 | ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ |
| 2813 | } else { \ |
| 2814 | ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ |
| 2815 | } \ |
| 2816 | } \ |
| 2817 | } \ |
| 2818 | } |
| 2819 | |
| 2820 | CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) |
| 2821 | CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) |
| 2822 | CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) |
| 2823 | CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) |
| 2824 | #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) |
| 2825 | if (device->coopmat1_fa_support) { |
| 2826 | CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) |
| 2827 | CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) |
| 2828 | CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) |
| 2829 | CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) |
| 2830 | } |
| 2831 | #endif |
| 2832 | #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 2833 | if (device->coopmat2) { |
| 2834 | CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) |
| 2835 | CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) |
| 2836 | CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) |
| 2837 | CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) |
| 2838 | CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) |
| 2839 | CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) |
| 2840 | CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) |
| 2841 | CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) |
| 2842 | } |
| 2843 | #endif |
| 2844 | #undef CREATE_FA |
| 2845 | |
| 2846 | #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 2847 | if (device->coopmat2) { |
| 2848 | |
| 2849 | // Create 6 variants, {s,m,l}x{unaligned,aligned} |
| 2850 | #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ |
| 2851 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ |
| 2852 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ |
| 2853 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ |
| 2854 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ |
| 2855 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ |
| 2856 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ |
| 2857 | |
| 2858 | // Create 2 variants, {f16,f32} accumulator |
| 2859 | #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ |
| 2860 | CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ |
| 2861 | CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ |
| 2862 | |
| 2863 | CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) |
| 2864 | #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 2865 | if (device->coopmat_bf16_support) { |
| 2866 | CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) |
| 2867 | } |
| 2868 | #endif |
| 2869 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2870 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2871 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2872 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2873 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2874 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) |
| 2875 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) |
| 2876 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) |
| 2877 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) |
| 2878 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) |
| 2879 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2880 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2881 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2882 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2883 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2884 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2885 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2886 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2887 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2888 | CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) |
| 2889 | |
| 2890 | GGML_ASSERT(device->subgroup_ballot); |
| 2891 | |
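|      | // The matmul_id pipelines below use the subgroup-based shader variants, which rely on subgroup
|      | // ballot support (asserted above).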
| 2892 | CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) |
| 2893 | #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 2894 | if (device->coopmat_bf16_support) { |
| 2895 | CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) |
| 2896 | } |
| 2897 | #endif |
| 2898 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2899 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2900 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2901 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2902 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2903 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2904 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2905 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2906 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2907 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2908 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2909 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2910 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2911 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2912 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2913 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2914 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2915 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2916 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2917 | CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) |
| 2918 | #undef CREATE_MM |
| 2919 | #undef CREATE_MM2 |
| 2920 | } else |
| 2921 | #endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 2922 | #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) |
| 2923 | if (device->coopmat_support) { |
| 2924 | // Create 6 variants, {s,m,l}x{unaligned,aligned} |
| 2925 | #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ |
| 2926 | if (device->mul_mat ## ID ## _l[TYPE]) \ |
| 2927 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ |
| 2928 | if (device->mul_mat ## ID ## _m[TYPE]) \ |
| 2929 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ |
| 2930 | if (device->mul_mat ## ID ## _s[TYPE]) \ |
| 2931 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ |
| 2932 | if (device->mul_mat ## ID ## _l[TYPE]) \ |
| 2933 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ |
| 2934 | if (device->mul_mat ## ID ## _m[TYPE]) \ |
| 2935 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ |
| 2936 | if (device->mul_mat ## ID ## _s[TYPE]) \ |
| 2937 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ |
| 2938 | |
| 2939 | // Create 2 variants, {f16,f32} accumulator |
| 2940 | #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ |
| 2941 | if (device->coopmat_acc_f16_support) { \ |
| 2942 | CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ |
| 2943 | } \ |
| 2944 | if (device->coopmat_acc_f32_support) { \ |
| 2945 | CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ |
| 2946 | } \ |
| 2947 | |
| 2948 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |
| 2949 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |
| 2950 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |
| 2951 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |
| 2952 | #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 2953 | if (device->coopmat_bf16_support) { |
| 2954 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ) |
| 2955 | } |
| 2956 | #endif |
| 2957 | |
| 2958 | if (device->coopmat_acc_f16_support) { |
| 2959 | CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2960 | CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2961 | CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2962 | CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2963 | CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2964 | |
| 2965 | CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2966 | CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2967 | CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2968 | CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2969 | CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2970 | CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2971 | CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2972 | CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2973 | CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2974 | CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2975 | CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2976 | CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2977 | CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2978 | CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2979 | CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2980 | } else { |
| 2981 | CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2982 | CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2983 | CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2984 | CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2985 | CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2986 | |
| 2987 | CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2988 | CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2989 | CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2990 | CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2991 | CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2992 | CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2993 | CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2994 | CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2995 | CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2996 | CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2997 | CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2998 | CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 2999 | CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 3000 | CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 3001 | CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |
| 3002 | } |
| 3003 | |
| 3004 | GGML_ASSERT(device->subgroup_ballot); |
| 3005 | |
| 3006 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); |
| 3007 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); |
| 3008 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); |
| 3009 | #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 3010 | if (device->coopmat_bf16_support) { |
| 3011 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); |
| 3012 | } |
| 3013 | #endif |
| 3014 | |
| 3015 | CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3016 | CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3017 | CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3018 | CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3019 | CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3020 | CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3021 | CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3022 | CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3023 | CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3024 | CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3025 | CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3026 | CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3027 | CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3028 | CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3029 | CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3030 | CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3031 | CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3032 | CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3033 | CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3034 | CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |
| 3035 | #undef CREATE_MM2 |
| 3036 | #undef CREATE_MM |
| 3037 | } else |
| 3038 | #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) |
| 3039 | if (device->fp16) { |
| 3040 | // Create 6 variants, {s,m,l}x{unaligned,aligned} |
| 3041 | #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ |
| 3042 | if (device->mul_mat ## ID ## _l[TYPE]) \ |
| 3043 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3044 | if (device->mul_mat ## ID ## _m[TYPE]) \ |
| 3045 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3046 | if (device->mul_mat ## ID ## _s[TYPE]) \ |
| 3047 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3048 | if (device->mul_mat ## ID ## _l[TYPE]) \ |
| 3049 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3050 | if (device->mul_mat ## ID ## _m[TYPE]) \ |
| 3051 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3052 | if (device->mul_mat ## ID ## _s[TYPE]) \ |
| 3053 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3054 | |
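|      | // CREATE_MMQ creates the {s,m,l} variants for integer-dot (q8_1) matmul; only f32 accumulator
|      | // pipelines are created for these.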
| 3055 | #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ |
| 3056 | if (device->mul_mat ## ID ## _l[TYPE]) { \ |
| 3057 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3058 | } \ |
| 3059 | if (device->mul_mat ## ID ## _m[TYPE]) { \ |
| 3060 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3061 | } \ |
| 3062 | if (device->mul_mat ## ID ## _s[TYPE]) { \ |
| 3063 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3064 | } \ |
| 3065 | |
| 3066 | // Create 2 variants, {f16,f32} accumulator |
| 3067 | #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ |
| 3068 | CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ |
| 3069 | CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ |
| 3070 | |
| 3071 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3072 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3073 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3074 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3075 | |
| 3076 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3077 | |
| 3078 | CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3079 | CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3080 | CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3081 | CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3082 | CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3083 | |
| 3084 | CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3085 | CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3086 | CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3087 | CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3088 | CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3089 | CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3090 | CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3091 | CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3092 | CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3093 | CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3094 | CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3095 | CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3096 | CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3097 | CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3098 | CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3099 | |
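|      | // Integer dot product (q8_1) matmul pipelines, gated on device support.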
| 3100 | #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) |
| 3101 | if (device->integer_dot_product) { |
| 3102 | CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); |
| 3103 | CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); |
| 3104 | CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); |
| 3105 | CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); |
| 3106 | CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); |
| 3107 | |
| 3108 | CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); |
| 3109 | |
| 3110 | CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); |
| 3111 | CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); |
| 3112 | CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); |
| 3113 | CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); |
| 3114 | CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); |
| 3115 | } |
| 3116 | #endif |
| 3117 | |
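|      | // Prefer the subgroup-optimized matmul_id variants when ballot, full-subgroup control and a
|      | // minimum subgroup size of 16 are available; otherwise fall back to the generic variants below.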
| 3118 | if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { |
| 3119 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3120 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3121 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3122 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3123 | |
| 3124 | CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3125 | CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3126 | CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3127 | CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3128 | CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3129 | CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3130 | CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3131 | CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3132 | CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3133 | CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3134 | CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3135 | CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3136 | CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3137 | CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3138 | CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3139 | CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3140 | CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3141 | CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3142 | CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3143 | CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3144 | |
| 3145 | #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) |
| 3146 | if (device->integer_dot_product) { |
| 3147 | CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3148 | CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3149 | CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3150 | CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3151 | CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3152 | |
| 3153 | CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3154 | |
| 3155 | CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3156 | CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3157 | CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3158 | CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3159 | CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3160 | } |
| 3161 | #endif |
| 3162 | } else { |
| 3163 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); |
| 3164 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); |
| 3165 | CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); |
| 3166 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3167 | |
| 3168 | CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3169 | CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3170 | CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3171 | CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3172 | CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3173 | CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3174 | CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3175 | CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3176 | CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3177 | CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3178 | CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3179 | CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3180 | CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3181 | CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3182 | CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3183 | CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3184 | CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3185 | CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3186 | CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3187 | CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3188 | |
| 3189 | #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) |
| 3190 | if (device->integer_dot_product) { |
| 3191 | CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3192 | CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3193 | CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3194 | CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3195 | CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3196 | |
| 3197 | CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3198 | |
| 3199 | CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3200 | CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3201 | CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3202 | CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3203 | CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3204 | } |
| 3205 | #endif |
| 3206 | } |
| 3207 | #undef CREATE_MM2 |
| 3208 | #undef CREATE_MMQ |
| 3209 | #undef CREATE_MM |
| 3210 | } else { |
| 3211 | // Create 6 variants, {s,m,l}x{unaligned,aligned} |
| 3212 | #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ |
| 3213 | if (device->mul_mat ## ID ## _l[TYPE]) \ |
| 3214 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3215 | if (device->mul_mat ## ID ## _m[TYPE]) \ |
| 3216 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3217 | if (device->mul_mat ## ID ## _s[TYPE]) \ |
| 3218 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3219 | if (device->mul_mat ## ID ## _l[TYPE]) \ |
| 3220 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3221 | if (device->mul_mat ## ID ## _m[TYPE]) \ |
| 3222 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3223 | if (device->mul_mat ## ID ## _s[TYPE]) \ |
| 3224 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ |
| 3225 | |
| 3226 | #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ |
| 3227 | if (device->mul_mat ## ID ## _l[TYPE]) \ |
| 3228 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ |
| 3229 | if (device->mul_mat ## ID ## _m[TYPE]) \ |
| 3230 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ |
| 3231 | if (device->mul_mat ## ID ## _s[TYPE]) \ |
| 3232 | ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ |
| 3233 | |
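| | // Illustrative expansion (sketch, not compiler output): a single invocation such as
| | //   CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
| | // emits up to six guarded ggml_vk_create_pipeline() calls, e.g. for the unaligned large tile:
| | //   if (device->mul_mat_l[GGML_TYPE_F32])
| | //       ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, l_warptile, 1, false, false, 0);
| | // plus the matching _m/_s and _aligned_{l,m,s} variants.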
| 3234 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3235 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3236 | CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3237 | CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3238 | |
| 3239 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3240 | |
| 3241 | CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3242 | CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3243 | CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3244 | CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3245 | CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3246 | |
| 3247 | CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3248 | CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3249 | CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3250 | CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3251 | CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3252 | CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3253 | CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3254 | CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3255 | CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3256 | CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3257 | CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3258 | CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3259 | CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3260 | CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3261 | CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); |
| 3262 | |
| 3263 | #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) |
| 3264 | if (device->integer_dot_product) { |
| 3265 | CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); |
| 3266 | CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); |
| 3267 | CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); |
| 3268 | CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); |
| 3269 | CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); |
| 3270 | |
| 3271 | CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); |
| 3272 | CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); |
| 3273 | CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); |
| 3274 | CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); |
| 3275 | CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); |
| 3276 | } |
| 3277 | #endif |
| 3278 | |
| 3279 | if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { |
| 3280 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3281 | CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3282 | CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3283 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); |
| 3284 | |
| 3285 | CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3286 | CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3287 | CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3288 | CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3289 | CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3290 | CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3291 | CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3292 | CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3293 | CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3294 | CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3295 | CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3296 | CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3297 | CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3298 | CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3299 | CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3300 | CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3301 | CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3302 | CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3303 | CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3304 | CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); |
| 3305 | } else { |
| 3306 | CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); |
| 3307 | CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); |
| 3308 | CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); |
| 3309 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3310 | |
| 3311 | CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3312 | CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3313 | CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3314 | CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3315 | CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3316 | CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3317 | CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3318 | CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3319 | CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3320 | CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3321 | CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3322 | CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3323 | CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3324 | CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3325 | CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3326 | CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3327 | CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3328 | CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3329 | CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3330 | CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3331 | } |
| 3332 | } |
| 3333 | // reusing CREATE_MM from the fp32 path |
| 3334 | if ((device->coopmat2 || device->coopmat_support) |
| 3335 | #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 3336 | && !device->coopmat_bf16_support |
| 3337 | #endif |
| 3338 | ) { |
| 3339 | // use scalar tile sizes |
| 3340 | l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; |
| 3341 | m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; |
| 3342 | s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; |
| 3343 | |
| 3344 | l_wg_denoms = {128, 128, 1 }; |
| 3345 | m_wg_denoms = { 64, 64, 1 }; |
| 3346 | s_wg_denoms = { 32, 32, 1 }; |
| 3347 | |
| 3348 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); |
| 3349 | CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); |
| 3350 | } |
| 3351 | #undef CREATE_MM |
| 3352 | |
| 3353 | // mul mat vec |
| 3354 | |
| 3355 | // the number of rows computed per shader depends on GPU model and quant |
| 3356 | uint32_t rm_stdq = 1; |
| 3357 | uint32_t rm_kq = 2; |
| 3358 | if (device->vendor_id == VK_VENDOR_ID_AMD) { |
| 3359 | if (device->architecture == AMD_GCN) { |
| 3360 | rm_stdq = 2; |
| 3361 | rm_kq = 4; |
| 3362 | } |
| 3363 | } else if (device->vendor_id == VK_VENDOR_ID_INTEL) |
| 3364 | rm_stdq = 2; |
| 3365 | uint32_t rm_iq = 2 * rm_kq; |
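| | // Summary of the row counts chosen above: standard quants default to rm_stdq = 1 and k-quants to rm_kq = 2;
| | // AMD GCN doubles both (2 and 4), Intel doubles only rm_stdq, and the i-quant/MXFP4 shaders always use rm_iq = 2 * rm_kq.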
| 3366 | |
| 3367 | const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; |
| 3368 | // Ensure a subgroup size >= 16 is available |
| 3369 | const bool use_subgroups16 = use_subgroups && subgroup_min_size_16; |
| 3370 | |
| 3371 | const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size; |
| 3372 | const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);
| 3373 | |
| 3374 | const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0; |
| 3375 | const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0; |
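| | // Subgroup-size selection: on Intel, a subgroup size of 16 is preferred when subgroup size control allows it,
| | // otherwise the device default is used. subgroup_size16 is clamped to at least 16 for the k-/i-quant variants,
| | // and the force_* values request a fixed subgroup size only when the subgroup reduction paths are enabled.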
| 3376 | |
| 3377 | for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) { |
| 3378 | const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4); |
| 3379 | const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4); |
| 3380 | |
| 3381 | const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : |
| 3382 | (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID : |
| 3383 | SHADER_REDUCTION_MODE_SHMEM; |
| 3384 | |
| 3385 | const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : |
| 3386 | (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID : |
| 3387 | SHADER_REDUCTION_MODE_SHMEM; |
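| | // Reduction mode per workgroup-size variant (a reading of the selection above): the subgroup-sized workgroup
| | // reduces entirely within a subgroup, the larger 4x-subgroup workgroup presumably combines a subgroup reduction
| | // with a shared-memory step (HYBRID), and devices without usable subgroup arithmetic fall back to shared memory.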
| 3388 | |
| 3389 | for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { |
| 3390 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32" , arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); |
| 3391 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32" , arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); |
| 3392 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32" , arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); |
| 3393 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32" , arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3394 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32" , arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3395 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32" , arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3396 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32" , arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3397 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32" , arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3398 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32" , arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3399 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32" , arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3400 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32" , arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3401 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32" , arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3402 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32" , arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3403 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32" , arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3404 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32" , arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3405 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32" , arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3406 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32" , arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3407 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32" , arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3408 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32" , arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3409 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32" , arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3410 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32" , arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3411 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32" , arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3412 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32" , arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3413 | |
| 3414 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32" , arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); |
| 3415 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32" , arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); |
| 3416 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32" , arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); |
| 3417 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32" , arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3418 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32" , arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3419 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32" , arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3420 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32" , arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3421 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32" , arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); |
| 3422 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32" , arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3423 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32" , arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3424 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32" , arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3425 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32" , arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3426 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32" , arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3427 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32" , arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3428 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32" , arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3429 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32" , arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3430 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32" , arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3431 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32" , arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3432 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32" , arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3433 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32" , arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3434 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32" , arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3435 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32" , arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3436 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32" , arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main" , 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); |
| 3437 | |
| 3438 | #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) |
| 3439 | if (device->integer_dot_product) { |
| 3440 | const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; |
| 3441 | const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); |
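| | // For the q8_1 integer-dot path, Intel devices with subgroup size control use their minimum subgroup size;
| | // other vendors keep the default, and the same subgroup vs. large (4x) workgroup split applies.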
| 3442 | |
| 3443 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32" , arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); |
| 3444 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32" , arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); |
| 3445 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32" , arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); |
| 3446 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32" , arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); |
| 3447 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32" , arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main" , 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); |
| 3448 | } |
| 3449 | #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT |
| 3450 | } |
| 3451 | } |
| 3452 | |
| 3453 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32" , mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); |
| 3454 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32" , mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); |
| 3455 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32" , mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); |
| 3456 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32" , mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); |
| 3457 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32" , mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); |
| 3458 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32" , mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); |
| 3459 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32" , mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); |
| 3460 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32" , mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); |
| 3461 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32" , mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); |
| 3462 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32" , mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); |
| 3463 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32" , mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); |
| 3464 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32" , mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); |
| 3465 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32" , mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); |
| 3466 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32" , mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3467 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32" , mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3468 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32" , mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3469 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32" , mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3470 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32" , mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3471 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32" , mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3472 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32" , mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3473 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32" , mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3474 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32" , mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3475 | ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32" , mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main" , 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); |
| 3476 | |
| 3477 | // dequant shaders |
| 3478 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16" , dequant_f32_len, dequant_f32_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3479 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0" , dequant_q4_0_len, dequant_q4_0_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3480 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1" , dequant_q4_1_len, dequant_q4_1_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3481 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0" , dequant_q5_0_len, dequant_q5_0_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3482 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1" , dequant_q5_1_len, dequant_q5_1_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3483 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0" , dequant_q8_0_len, dequant_q8_0_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3484 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k" , dequant_q2_k_len, dequant_q2_k_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); |
| 3485 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k" , dequant_q3_k_len, dequant_q3_k_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); |
| 3486 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k" , dequant_q4_k_len, dequant_q4_k_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3487 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k" , dequant_q5_k_len, dequant_q5_k_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); |
| 3488 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k" , dequant_q6_k_len, dequant_q6_k_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); |
| 3489 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s" , dequant_iq1_s_len, dequant_iq1_s_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3490 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m" , dequant_iq1_m_len, dequant_iq1_m_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3491 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs" , dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3492 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs" , dequant_iq2_xs_len, dequant_iq2_xs_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3493 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s" , dequant_iq2_s_len, dequant_iq2_s_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3494 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs" , dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3495 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s" , dequant_iq3_s_len, dequant_iq3_s_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3496 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs" , dequant_iq4_xs_len, dequant_iq4_xs_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); |
| 3497 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl" , dequant_iq4_nl_len, dequant_iq4_nl_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3498 | ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4" , dequant_mxfp4_len, dequant_mxfp4_data, "main" , 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); |
| 3499 | |
| 3500 | // get_rows |
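// Row-gather pipelines: 3 descriptors (src data, row indices, dst).
// pipeline_get_rows[] keeps the source/f16 output path; pipeline_get_rows_f32[] (below) is the f32-destination variant.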
| 3501 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32" , get_rows_f32_len, get_rows_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); |
| 3502 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16" , get_rows_f16_len, get_rows_f16_data, "main" , 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); |
| 3503 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16" , get_rows_bf16_len, get_rows_bf16_data, "main" , 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); |
| 3504 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0" , get_rows_q4_0_len, get_rows_q4_0_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3505 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1" , get_rows_q4_1_len, get_rows_q4_1_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3506 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0" , get_rows_q5_0_len, get_rows_q5_0_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3507 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1" , get_rows_q5_1_len, get_rows_q5_1_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3508 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0" , get_rows_q8_0_len, get_rows_q8_0_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3509 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k" , get_rows_q2_k_len, get_rows_q2_k_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3510 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k" , get_rows_q3_k_len, get_rows_q3_k_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3511 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k" , get_rows_q4_k_len, get_rows_q4_k_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3512 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k" , get_rows_q5_k_len, get_rows_q5_k_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3513 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k" , get_rows_q6_k_len, get_rows_q6_k_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3514 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s" , get_rows_iq1_s_len, get_rows_iq1_s_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3515 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m" , get_rows_iq1_m_len, get_rows_iq1_m_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3516 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs" , get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3517 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs" , get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3518 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s" , get_rows_iq2_s_len, get_rows_iq2_s_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3519 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs" , get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3520 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s" , get_rows_iq3_s_len, get_rows_iq3_s_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3521 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs" , get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3522 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl" , get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3523 | ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4" , get_rows_mxfp4_len, get_rows_mxfp4_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3524 | |
| 3525 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32" , get_rows_f32_f32_len, get_rows_f32_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); |
| 3526 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32" , get_rows_f16_f32_len, get_rows_f16_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); |
| 3527 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32" , get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); |
| 3528 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32" , get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3529 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32" , get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3530 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32" , get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3531 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32" , get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3532 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32" , get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3533 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32" , get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3534 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32" , get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3535 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32" , get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3536 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32" , get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3537 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32" , get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3538 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32" , get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3539 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32" , get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3540 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32" , get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3541 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32" , get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3542 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32" , get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3543 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32" , get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3544 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32" , get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3545 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32" , get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3546 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32" , get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3547 | ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32" , get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); |
| 3548 | |
| 3549 | ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce" , split_k_reduce_len, split_k_reduce_data, "main" , 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); |
| 3550 | ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce" , fa_split_k_reduce_len, fa_split_k_reduce_data, "main" , 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); |
| 3551 | |
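// Prefer the subgroup-clustered quantize_q8_1 shader variants when clustered subgroup ops with
// full-subgroup support are available; otherwise fall back to the generic shaders.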
| 3552 | if (device->subgroup_clustered && device->subgroup_require_full_support) { |
| 3553 | ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1" , quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main" , 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); |
| 3554 | ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4" , quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main" , 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); |
| 3555 | } else { |
| 3556 | ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1" , quantize_q8_1_len, quantize_q8_1_data, "main" , 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); |
| 3557 | ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4" , quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main" , 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); |
| 3558 | } |
| 3559 | |
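// mul_mat_vec_p021 is specialized per GQA ratio (spec constant i + 1); the subgroupAdd shader
// variant is used when subgroup arithmetic with full-subgroup support is available.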
| 3560 | for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { |
| 3561 | if (device->subgroup_arithmetic && device->subgroup_require_full_support) { |
| 3562 | ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32" +std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main" , 4, 7 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
| 3563 | } else {
| 3564 | ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32" +std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main" , 4, 7 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
| 3565 | } |
| 3566 | } |
| 3567 | ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32" , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main" , 4, 13 * sizeof(uint32_t), {1, 1, 1}, {}, 1); |
| 3568 | |
| 3569 | ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32" , norm_f32_len, norm_f32_data, "main" , 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); |
| 3570 | ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32" , group_norm_f32_len, group_norm_f32_data, "main" , 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); |
| 3571 | |
| 3572 | ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32" , rms_norm_f32_len, rms_norm_f32_data, "main" , 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); |
| 3573 | ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32" , rms_norm_f32_len, rms_norm_f32_data, "main" , 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); |
| 3574 | ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32" , rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main" , 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); |
| 3575 | ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32" , rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main" , 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); |
| 3576 | |
| 3577 | if (device->float_controls_rte_fp16 && |
| 3578 | sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { |
| 3579 | ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32" , rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main" , 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); |
| 3580 | ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16" , rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main" , 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); |
| 3581 | } |
| 3582 | |
| 3583 | ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32" , rms_norm_back_f32_len, rms_norm_back_f32_data, "main" , 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); |
| 3584 | ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32" , l2_norm_f32_len, l2_norm_f32_data, "main" , 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); |
| 3585 | |
| 3586 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32" , cpy_f32_f32_len, cpy_f32_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3587 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16" , cpy_f32_f16_len, cpy_f32_f16_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3588 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16" , cpy_f16_f16_len, cpy_f16_f16_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3589 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32" , cpy_f16_f32_len, cpy_f16_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3590 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16, "cpy_f32_bf16" , cpy_f32_bf16_len, cpy_f32_bf16_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
| 3591 | ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32" , cpy_i32_f32_len, cpy_i32_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3592 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32" , cpy_f32_i32_len, cpy_f32_i32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3593 | |
| 3594 | ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32" , contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3595 | ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16" , contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3596 | ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16" , contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3597 | ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32" , contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3598 | ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16, "contig_cpy_f32_bf16" , contig_cpy_f32_bf16_len, contig_cpy_f32_bf16_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
| 3599 | ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32" , contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3600 | ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32" , contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3601 | |
| 3602 | if (device->float_controls_rte_fp16) { |
| 3603 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0" , cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3604 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1" , cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3605 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0" , cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3606 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1" , cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3607 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0" , cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3608 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl" , cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3609 | } else { |
| 3610 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0" , cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3611 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1" , cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3612 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0" , cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3613 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1" , cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3614 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0" , cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3615 | ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl" , cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main" , 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); |
| 3616 | } |
| 3617 | |
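// SET_ROWS instantiates set_rows pipelines for i32 and i64 row-index types; the _rte shader
// variants are chosen when the device guarantees round-to-nearest-even fp16 float controls.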
| 3618 | #define SET_ROWS(itype, rte) \ |
| 3619 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3620 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3621 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3622 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3623 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3624 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3625 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3626 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ |
| 3627 | ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); |
| 3628 | |
| 3629 | if (device->float_controls_rte_fp16) { |
| 3630 | SET_ROWS(_i32, _rte) |
| 3631 | SET_ROWS(_i64, _rte) |
| 3632 | } else { |
| 3633 | SET_ROWS(_i32, ) |
| 3634 | SET_ROWS(_i64, ) |
| 3635 | } |
| 3636 | #undef SET_ROWS |
| 3637 | |
| 3638 | |
| 3639 | ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32" , cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
| 3640 | ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32" , cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
| 3641 | ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32" , cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
| 3642 | ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32" , cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
| 3643 | ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32" , cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
| 3644 | ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32" , cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
| 3645 | |
| 3646 | auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { |
| 3647 | std::string s; |
| 3648 | s += std::string(src0_f16 ? "_f16" : "_f32" ); |
| 3649 | s += std::string(src1_f16 ? "_f16" : "_f32" ); |
| 3650 | s += std::string(dst_f16 ? "_f16" : "_f32" ); |
| 3651 | return s; |
| 3652 | }; |
| 3653 | |
| 3654 | bool rte = device->float_controls_rte_fp16; |
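// CREATE_BINARY instantiates 2x2x2 = 8 pipelines per op (src0/src1/dst each f16 or f32);
// the _norepeat variants differ only in the {1} specialization constant.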
| 3655 | #define CREATE_BINARY(name, namemod, spec, bindings) \ |
| 3656 | for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ |
| 3657 | ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ |
| 3658 | #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ |
| 3659 | "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); |
| 3660 | |
| 3661 | CREATE_BINARY(add, , {0}, 4) |
| 3662 | CREATE_BINARY(add, _norepeat, {1}, 4) |
| 3663 | CREATE_BINARY(sub, , {0}, 3) |
| 3664 | CREATE_BINARY(sub, _norepeat, {1}, 3) |
| 3665 | CREATE_BINARY(mul, , {0}, 3) |
| 3666 | CREATE_BINARY(mul, _norepeat, {1}, 3) |
| 3667 | CREATE_BINARY(div, , {0}, 3) |
| 3668 | CREATE_BINARY(div, _norepeat, {1}, 3) |
| 3669 | CREATE_BINARY(add_rms, , {0}, 4) |
| 3670 | CREATE_BINARY(add_rms, _norepeat, {1}, 4) |
| 3671 | #undef CREATE_BINARY |
| 3672 | |
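// Fused multi-add pipelines (one per fused operand count, up to MAX_FUSED_ADDS) are only built
// when the device supports the multi-add path.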
| 3673 | if (device->multi_add) { |
| 3674 | for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) { |
| 3675 | ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main" , MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
| 3676 | ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main" , MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
| 3677 | } |
| 3678 | } |
| 3679 | |
| 3680 | ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32" , add_id_f32_len, add_id_f32_data, "main" , 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); |
| 3681 | |
| 3682 | ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32" , acc_f32_len, acc_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); |
| 3683 | |
| 3684 | ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32" , concat_f32_len, concat_f32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); |
| 3685 | ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16" , concat_f16_len, concat_f16_data, "main" , 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); |
| 3686 | ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32" , concat_i32_len, concat_i32_data, "main" , 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); |
| 3687 | |
| 3688 | ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32" , upscale_f32_len, upscale_f32_data, "main" , 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); |
| 3689 | ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32" , upscale_f32_len, upscale_f32_data, "main" , 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); |
| 3690 | |
| 3691 | ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32" , scale_f32_len, scale_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3692 | |
| 3693 | ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32" , sqr_f32_len, sqr_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3694 | ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32" , sqrt_f32_len, sqrt_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3695 | ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32" , sin_f32_len, sin_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3696 | ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32" , cos_f32_len, cos_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3697 | |
| 3698 | ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32" , clamp_f32_len, clamp_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3699 | |
| 3700 | ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32" , pad_f32_len, pad_f32_data, "main" , 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); |
| 3701 | |
| 3702 | ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32" , roll_f32_len, roll_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3703 | |
| 3704 | ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32" , repeat_f32_len, repeat_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3705 | ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32" , repeat_back_f32_len, repeat_back_f32_data, "main" , 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); |
| 3706 | |
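// CREATE_UNARY builds an f32 ([0]) and an f16 ([1]) variant of each elementwise activation.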
| 3707 | #define CREATE_UNARY(name) \ |
| 3708 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ |
| 3709 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); |
| 3710 | |
| 3711 | CREATE_UNARY(gelu) |
| 3712 | CREATE_UNARY(gelu_erf) |
| 3713 | CREATE_UNARY(gelu_quick) |
| 3714 | CREATE_UNARY(silu) |
| 3715 | CREATE_UNARY(relu) |
| 3716 | CREATE_UNARY(tanh) |
| 3717 | CREATE_UNARY(sigmoid) |
| 3718 | CREATE_UNARY(hardsigmoid) |
| 3719 | CREATE_UNARY(hardswish) |
| 3720 | #undef CREATE_UNARY |
| 3721 | |
| 3722 | #define CREATE_UNARY_RTE(name) \ |
| 3723 | if (device->float_controls_rte_fp16) { \ |
| 3724 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ |
| 3725 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ |
| 3726 | } else { \ |
| 3727 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ |
| 3728 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ |
| 3729 | } |
| 3730 | CREATE_UNARY_RTE(exp) |
| 3731 | #undef CREATE_UNARY_RTE |
| 3732 | |
| 3733 | #define CREATE_GLU(name) \ |
| 3734 | if (device->float_controls_rte_fp16) { \ |
| 3735 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ |
| 3736 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ |
| 3737 | } else { \ |
| 3738 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ |
| 3739 | ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ |
| 3740 | } |
| 3741 | |
| 3742 | CREATE_GLU(geglu) |
| 3743 | CREATE_GLU(reglu) |
| 3744 | CREATE_GLU(swiglu) |
| 3745 | CREATE_GLU(swiglu_oai) |
| 3746 | CREATE_GLU(geglu_erf) |
| 3747 | CREATE_GLU(geglu_quick) |
| 3748 | #undef CREATE_GLU |
| 3749 | |
| 3750 | ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32" , leaky_relu_f32_len, leaky_relu_f32_data, "main" , 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); |
| 3751 | ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32" , silu_back_f32_len, silu_back_f32_data, "main" , 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); |
| 3752 | |
| 3753 | ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32" , diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main" , 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); |
| 3754 | |
| 3755 | ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32" , soft_max_f32_len, soft_max_f32_data, "main" , 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); |
| 3756 | ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512" , soft_max_f32_len, soft_max_f32_data, "main" , 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); |
| 3757 | ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16" , soft_max_f32_f16_len, soft_max_f32_f16_data, "main" , 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); |
| 3758 | ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512" , soft_max_f32_f16_len, soft_max_f32_f16_data, "main" , 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); |
| 3759 | ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32" , soft_max_back_f32_len, soft_max_back_f32_data, "main" , 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); |
| 3760 | |
| 3761 | ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32" , rope_norm_f32_len, rope_norm_f32_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3762 | ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32" , rope_neox_f32_len, rope_neox_f32_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3763 | ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32" , rope_multi_f32_len, rope_multi_f32_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3764 | ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32" , rope_vision_f32_len, rope_vision_f32_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3765 | |
| 3766 | if (device->float_controls_rte_fp16) { |
| 3767 | ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16" , rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3768 | ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16" , rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3769 | ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16" , rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3770 | ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16" , rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3771 | |
| 3772 | ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16" , rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3773 | ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16" , rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3774 | } else { |
| 3775 | ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16" , rope_norm_f16_len, rope_norm_f16_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3776 | ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16" , rope_neox_f16_len, rope_neox_f16_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3777 | ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16" , rope_multi_f16_len, rope_multi_f16_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3778 | ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16" , rope_vision_f16_len, rope_vision_f16_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3779 | |
| 3780 | ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16" , rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3781 | ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16" , rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main" , 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |
| 3782 | } |
| 3783 | |
| 3784 | for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { |
| 3785 | ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_" +std::to_string(i), argsort_f32_len, argsort_f32_data, "main" , 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
| 3786 | } |
| 3787 | |
| 3788 | ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32" , argmax_f32_len, argmax_f32_data, "main" , 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); |
| 3789 | |
| 3790 | ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, "main" , 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); |
| 3791 | |
| 3792 | ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32" , count_equal_i32_len, count_equal_i32_data, "main" , 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); |
| 3793 | |
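// IM2COL: the _bda shader variants use buffer device addresses and are selected only when
// shader_int64 and buffer_device_address are both supported; RTE fp16 variants when available.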
| 3794 | #define IM2COL(bda) \ |
| 3795 | ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ |
| 3796 | ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ |
| 3797 | if (device->float_controls_rte_fp16) { \ |
| 3798 | ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ |
| 3799 | ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ |
| 3800 | } else { \ |
| 3801 | ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ |
| 3802 | ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ |
| 3803 | } |
| 3804 | if (device->shader_int64 && device->buffer_device_address) { |
| 3805 | IM2COL(_bda) |
| 3806 | } else { |
| 3807 | IM2COL() |
| 3808 | } |
| 3809 | |
| 3810 | ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32" , timestep_embedding_f32_len, timestep_embedding_f32_data, "main" , 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); |
| 3811 | |
| 3812 | ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32" , conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main" , 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); |
| 3813 | |
| 3814 | ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32" , pool2d_f32_len, pool2d_f32_data, "main" , 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); |
| 3815 | |
| 3816 | ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32" , rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main" , 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); |
| 3817 | |
| 3818 | ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32" , rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main" , 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); |
| 3819 | |
| 3820 | if (device->subgroup_arithmetic && device->subgroup_require_full_support) { |
| 3821 | ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32" , ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main" , 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); |
| 3822 | ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32" , ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main" , 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); |
| 3823 | } else { |
| 3824 | ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32" , ssm_scan_f32_len, ssm_scan_f32_data, "main" , 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); |
| 3825 | ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32" , ssm_scan_f32_len, ssm_scan_f32_data, "main" , 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); |
| 3826 | } |
| 3827 | |
| 3828 | ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32" , ssm_conv_f32_len, ssm_conv_f32_data, "main" , 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); |
| 3829 | |
| 3830 | ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32" , opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main" , 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); |
| 3831 | |
| 3832 | ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32" , opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main" , 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); |
| 3833 | |
| 3834 | // conv2d, conv_transpose_2d |
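// One pipeline set per CONV_SHAPE_*; block sizes (BS_K x BS_NPQ, BS_CRS), tile size and
// shared-memory padding are tuned per vendor/architecture below.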
| 3835 | for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { |
| 3836 | uint32_t conv2d_WG_SIZE = 256; |
| 3837 | uint32_t conv2d_BS_K = 128; |
| 3838 | uint32_t conv2d_BS_CRS = 16; |
| 3839 | uint32_t use_collectives = 0; // Enables subgroup ops to avoid recomputing indices.
| 3840 | uint32_t conv2d_BS_NPQ = 128; |
| 3841 | uint32_t conv2d_TS_K = 8; |
| 3842 | uint32_t conv2d_SHMEM_PAD = 4; |
| 3843 | bool conv2d_UNROLL = true; |
| 3844 | |
| 3845 | #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 3846 | if (device->coopmat2) { |
| 3847 | conv2d_SHMEM_PAD = 8; // 8 float16_t |
| 3848 | } |
| 3849 | #endif |
| 3850 | |
| 3851 | if (device->vendor_id == VK_VENDOR_ID_INTEL) { |
| 3852 | conv2d_SHMEM_PAD = 0; |
| 3853 | conv2d_UNROLL = false; |
| 3854 | } else if (device->vendor_id == VK_VENDOR_ID_AMD) { |
| 3855 | conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4; |
| 3856 | } |
| 3857 | |
| 3858 | switch (s) { |
| 3859 | default: |
| 3860 | case CONV_SHAPE_128x128: |
| 3861 | conv2d_BS_K = 128; |
| 3862 | conv2d_BS_NPQ = 128; |
| 3863 | conv2d_BS_CRS = 16; |
| 3864 | if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) { |
| 3865 | conv2d_UNROLL = false; |
| 3866 | } |
| 3867 | break; |
| 3868 | case CONV_SHAPE_64x32: |
| 3869 | conv2d_BS_K = 64; |
| 3870 | conv2d_BS_NPQ = 32; |
| 3871 | conv2d_BS_CRS = 32; |
| 3872 | conv2d_TS_K = 4; |
| 3873 | break; |
| 3874 | case CONV_SHAPE_32x256: |
| 3875 | conv2d_BS_K = 32; |
| 3876 | conv2d_BS_NPQ = 256; |
| 3877 | conv2d_BS_CRS = 16; |
| 3878 | break; |
| 3879 | } |
| 3880 | |
| 3881 | // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math. |
| 3882 | bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA || |
| 3883 | device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; |
| 3884 | bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD || |
| 3885 | device->architecture == vk_device_architecture::AMD_GCN; |
| 3886 | |
| 3887 | if (device->subgroup_shuffle && |
| 3888 | device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316. |
| 3889 | allow_collectives_nv && |
| 3890 | allow_collectives_amd) { |
| 3891 | use_collectives = 1; |
| 3892 | conv2d_BS_CRS = std::min(
| 3893 | device->subgroup_size,
| 3894 | conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
| 3895 | } |
| 3896 | |
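// Estimate shared memory for the two tiles (BS_K x BS_CRS and BS_CRS x BS_NPQ, each padded by
// SHMEM_PAD, in floats); if it exceeds the device limit, shrink BS_CRS.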
| 3897 | uint32_t conv2d_shmem_req = |
| 3898 | (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float); |
| 3899 | if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { |
| 3900 | conv2d_BS_CRS = 8; |
| 3901 | if (use_collectives) { |
| 3902 | conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
| 3903 | } |
| 3904 | } |
| 3905 | |
| 3906 | std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; |
| 3907 | std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; |
| 3908 | |
| 3909 | #define CREATE_CONV(name, type_suffix, spv_suffix) \ |
| 3910 | ggml_vk_create_pipeline( \ |
| 3911 | device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \ |
| 3912 | name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ |
| 3913 | sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); |
| 3914 | #define CREATE_CONVS(spv_suffix) \ |
| 3915 | CREATE_CONV(conv2d, _f32, spv_suffix) \ |
| 3916 | CREATE_CONV(conv2d, _f16_f32, spv_suffix) \ |
| 3917 | if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \ |
| 3918 | CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \ |
| 3919 | CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \ |
| 3920 | } |
| 3921 | #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 3922 | if (device->coopmat2) { |
| 3923 | CREATE_CONVS(_cm2) |
| 3924 | } else |
| 3925 | #endif |
| 3926 | if (conv2d_UNROLL) { |
| 3927 | CREATE_CONVS(_unroll) |
| 3928 | } else { |
| 3929 | CREATE_CONVS( ) |
| 3930 | } |
| 3931 | #undef CREATE_CONV |
| 3932 | #undef CREATE_CONVS |
| 3933 | } |
| 3934 | |
| 3935 | ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32" , conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main" , 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); |
| 3936 | ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32" , conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main" , 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); |
| 3937 | ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32" , conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main" , 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); |
| 3938 | ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32" , conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main" , 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); |
| 3939 | |
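// topk_moe: one pipeline per power-of-two expert count (spec constant 1u << i), in
// early-softmax, early-softmax+norm and late-softmax variants.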
| 3940 | for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { |
| 3941 | ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_" +std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main" , 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
| 3942 | ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm" +std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main" , 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
| 3943 | ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax" +std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main" , 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
| 3944 | } |
| 3945 | |
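// Block until every pipeline compilation queued above has finished.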
| 3946 | for (auto &c : compiles) { |
| 3947 | c.wait(); |
| 3948 | } |
| 3949 | } |
| 3950 | |
| 3951 | static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); |
| 3952 | |
| 3953 | static vk_device ggml_vk_get_device(size_t idx) { |
| 3954 | VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")" ); |
| 3955 | |
| 3956 | if (vk_instance.devices[idx] == nullptr) { |
| 3957 | VK_LOG_DEBUG("Initializing new vk_device" ); |
| 3958 | vk_device device = std::make_shared<vk_device_struct>(); |
| 3959 | vk_instance.devices[idx] = device; |
| 3960 | |
| 3961 | #ifdef GGML_VULKAN_MEMORY_DEBUG |
| 3962 | device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger()); |
| 3963 | #endif |
| 3964 | if (vk_perf_logger_enabled) { |
| 3965 | device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger()); |
| 3966 | } |
| 3967 | |
| 3968 | size_t dev_num = vk_instance.device_indices[idx]; |
| 3969 | |
| 3970 | std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices(); |
| 3971 | |
| 3972 | if (dev_num >= physical_devices.size()) { |
| 3973 | std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; |
| 3974 | throw std::runtime_error("Device not found" ); |
| 3975 | } |
| 3976 | |
| 3977 | device->physical_device = physical_devices[dev_num]; |
| 3978 | const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties(); |
| 3979 | |
| 3980 | device->architecture = get_device_architecture(device->physical_device);
| 3981 | |
| 3982 | const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
| 3983 | device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
| 3984 | 
| 3985 | const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
| 3986 | device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
| 3987 | 
| 3988 | const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
| 3989 | device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
| 3990 | 
| 3991 | const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE");
| 3992 | device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr; |
| 3993 | |
| 3994 | bool fp16_storage = false; |
| 3995 | bool fp16_compute = false; |
| 3996 | bool maintenance4_support = false; |
| 3997 | bool sm_builtins = false; |
| 3998 | bool amd_shader_core_properties2 = false; |
| 3999 | bool pipeline_robustness = false; |
| 4000 | bool coopmat2_support = false; |
| 4001 | bool pipeline_executable_properties_support = false; |
| 4002 | device->coopmat_support = false; |
| 4003 | device->integer_dot_product = false; |
| 4004 | bool bfloat16_support = false; |
| 4005 | |
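// Scan the device extension list; several features can also be force-disabled via
// GGML_VK_DISABLE_* environment variables.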
| 4006 | for (const auto& properties : ext_props) { |
| 4007 | if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
| 4008 | maintenance4_support = true;
| 4009 | } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
| 4010 | fp16_storage = true;
| 4011 | } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
| 4012 | fp16_compute = true;
| 4013 | } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
| 4014 | sm_builtins = true;
| 4015 | } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
| 4016 | amd_shader_core_properties2 = true;
| 4017 | } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
| 4018 | pipeline_robustness = true;
| 4019 | } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
| 4020 | device->subgroup_size_control = true;
| 4021 | #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
| 4022 | } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
| 4023 | !getenv("GGML_VK_DISABLE_COOPMAT")) {
| 4024 | device->coopmat_support = true; |
| 4025 | device->coopmat_m = 0; |
| 4026 | device->coopmat_n = 0; |
| 4027 | device->coopmat_k = 0; |
| 4028 | #endif |
| 4029 | #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 4030 | } else if (strcmp(s1: "VK_NV_cooperative_matrix2" , s2: properties.extensionName) == 0 && |
| 4031 | !getenv(name: "GGML_VK_DISABLE_COOPMAT2" )) { |
| 4032 | coopmat2_support = true; |
| 4033 | #endif |
| 4034 | #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) |
| 4035 | } else if (strcmp(s1: "VK_KHR_shader_integer_dot_product" , s2: properties.extensionName) == 0 && |
| 4036 | !getenv(name: "GGML_VK_DISABLE_INTEGER_DOT_PRODUCT" )) { |
| 4037 | device->integer_dot_product = true; |
| 4038 | #endif |
| 4039 | #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 4040 | } else if (strcmp(s1: "VK_KHR_shader_bfloat16" , s2: properties.extensionName) == 0 && |
| 4041 | !getenv(name: "GGML_VK_DISABLE_BFLOAT16" )) { |
| 4042 | bfloat16_support = true; |
| 4043 | #endif |
| 4044 | } else if (strcmp(s1: "VK_KHR_pipeline_executable_properties" , s2: properties.extensionName) == 0) { |
| 4045 | pipeline_executable_properties_support = true; |
| 4046 | } |
| 4047 | } |
| 4048 | |
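| | // Query extended properties through a pNext chain: the core structs are always linked, while extension-specific property structs are appended only when the matching extension was found above. |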
| 4049 | vk::PhysicalDeviceProperties2 props2; |
| 4050 | vk::PhysicalDeviceMaintenance3Properties props3; |
| 4051 | vk::PhysicalDeviceMaintenance4Properties props4; |
| 4052 | vk::PhysicalDeviceSubgroupProperties subgroup_props; |
| 4053 | vk::PhysicalDeviceDriverProperties driver_props; |
| 4054 | vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; |
| 4055 | vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; |
| 4056 | vk::PhysicalDeviceVulkan11Properties vk11_props; |
| 4057 | vk::PhysicalDeviceVulkan12Properties vk12_props; |
| 4058 | vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; |
| 4059 | vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; |
| 4060 | |
| 4061 | props2.pNext = &props3; |
| 4062 | props3.pNext = &subgroup_props; |
| 4063 | subgroup_props.pNext = &driver_props; |
| 4064 | driver_props.pNext = &vk11_props; |
| 4065 | vk11_props.pNext = &vk12_props; |
| 4066 | |
| 4067 | VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; |
| 4068 | |
| 4069 | if (maintenance4_support) { |
| 4070 | last_struct->pNext = (VkBaseOutStructure *)&props4; |
| 4071 | last_struct = (VkBaseOutStructure *)&props4; |
| 4072 | } |
| 4073 | if (sm_builtins) { |
| 4074 | last_struct->pNext = (VkBaseOutStructure *)&sm_props; |
| 4075 | last_struct = (VkBaseOutStructure *)&sm_props; |
| 4076 | } |
| 4077 | if (amd_shader_core_properties2) { |
| 4078 | last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; |
| 4079 | last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; |
| 4080 | } |
| 4081 | if (device->subgroup_size_control) { |
| 4082 | last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; |
| 4083 | last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; |
| 4084 | } |
| 4085 | |
| 4086 | #if defined(VK_NV_cooperative_matrix2) |
| 4087 | vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; |
| 4088 | if (coopmat2_support) { |
| 4089 | last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; |
| 4090 | last_struct = (VkBaseOutStructure *)&coopmat2_props; |
| 4091 | } |
| 4092 | #endif |
| 4093 | |
| 4094 | if (device->integer_dot_product) { |
| 4095 | last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; |
| 4096 | last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; |
| 4097 | } |
| 4098 | |
| 4099 | device->physical_device.getProperties2(&props2); |
| 4100 | device->properties = props2.properties; |
| 4101 | device->vendor_id = device->properties.vendorID; |
| 4102 | device->driver_id = driver_props.driverID; |
| 4103 | |
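| | // The allocation and buffer size limits can be overridden via GGML_VK_FORCE_MAX_ALLOCATION_SIZE and GGML_VK_FORCE_MAX_BUFFER_SIZE (values in bytes). |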
| 4104 | const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); |
| 4105 | |
| 4106 | if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { |
| 4107 | device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); |
| 4108 | } else if (maintenance4_support) { |
| 4109 | device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); |
| 4110 | } else { |
| 4111 | device->max_memory_allocation_size = props3.maxMemoryAllocationSize; |
| 4112 | } |
| 4113 | |
| 4114 | const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE"); |
| 4115 | |
| 4116 | if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) { |
| 4117 | device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE); |
| 4118 | } else if (maintenance4_support) { |
| 4119 | device->max_buffer_size = props4.maxBufferSize; |
| 4120 | } else { |
| 4121 | device->max_buffer_size = device->max_memory_allocation_size; |
| 4122 | } |
| 4123 | |
| 4124 | const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); |
| 4125 | |
| 4126 | if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { |
| 4127 | device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE); |
| 4128 | } else { |
| 4129 | // Limit batching of allocations to 1GB by default to avoid fragmentation issues |
| 4130 | device->suballocation_block_size = 1024*1024*1024; |
| 4131 | } |
| 4132 | device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); |
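| | // Example: GGML_VK_SUBALLOCATION_BLOCK_SIZE=536870912 caps suballocation blocks at 512 MiB; the value is always clamped to the maximum allocation size. |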
| 4133 | |
| 4134 | device->subgroup_size = subgroup_props.subgroupSize; |
| 4135 | device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; |
| 4136 | if (sm_builtins) { |
| 4137 | device->shader_core_count = sm_props.shaderSMCount; |
| 4138 | } else if (amd_shader_core_properties2) { |
| 4139 | device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; |
| 4140 | } else { |
| 4141 | device->shader_core_count = 0; |
| 4142 | } |
| 4143 | device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; |
| 4144 | |
| 4145 | device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && |
| 4146 | (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); |
| 4147 | #ifdef __APPLE__ |
| 4148 | // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) |
| 4149 | if (device->vendor_id == VK_VENDOR_ID_AMD) { |
| 4150 | device->subgroup_arithmetic = false; |
| 4151 | } |
| 4152 | #endif |
| 4153 | device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && |
| 4154 | (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); |
| 4155 | device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && |
| 4156 | (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); |
| 4157 | |
| 4158 | device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && |
| 4159 | (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot); |
| 4160 | |
| 4161 | const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; |
| 4162 | |
| 4163 | device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; |
| 4164 | |
| 4165 | if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) { |
| 4166 | device->coopmat_support = false; |
| 4167 | } |
| 4168 | |
| 4169 | device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; |
| 4170 | |
| 4171 | std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties(); |
| 4172 | |
| 4173 | // Try to find a non-graphics compute queue and transfer-focused queues |
| 4174 | const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); |
| 4175 | const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); |
| 4176 | |
| 4177 | const float priorities[] = { 1.0f, 1.0f }; |
| 4178 | device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; |
| 4179 | |
| 4180 | std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos; |
| 4181 | if (compute_queue_family_index != transfer_queue_family_index) { |
| 4182 | device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); |
| 4183 | device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); |
| 4184 | } else if (!device->single_queue) { |
| 4185 | device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); |
| 4186 | } else { |
| 4187 | device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); |
| 4188 | } |
| 4189 | vk::DeviceCreateInfo device_create_info; |
| 4190 | std::vector<const char *> device_extensions; |
| 4191 | vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); |
| 4192 | |
| 4193 | VkPhysicalDeviceFeatures2 device_features2; |
| 4194 | device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; |
| 4195 | device_features2.pNext = nullptr; |
| 4196 | device_features2.features = (VkPhysicalDeviceFeatures)device_features; |
| 4197 | |
| 4198 | VkPhysicalDeviceVulkan11Features vk11_features; |
| 4199 | vk11_features.pNext = nullptr; |
| 4200 | vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; |
| 4201 | device_features2.pNext = &vk11_features; |
| 4202 | |
| 4203 | VkPhysicalDeviceVulkan12Features vk12_features; |
| 4204 | vk12_features.pNext = nullptr; |
| 4205 | vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; |
| 4206 | vk11_features.pNext = &vk12_features; |
| 4207 | |
| 4208 | last_struct = (VkBaseOutStructure *)&vk12_features; |
| 4209 | |
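| | // Mirror the property query above: optional feature structs are linked into the pNext chain only when the corresponding extension was detected. |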
| 4210 | VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; |
| 4211 | pl_robustness_features.pNext = nullptr; |
| 4212 | pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; |
| 4213 | pl_robustness_features.pipelineRobustness = VK_FALSE; |
| 4214 | |
| 4215 | if (pipeline_robustness) { |
| 4216 | last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; |
| 4217 | last_struct = (VkBaseOutStructure *)&pl_robustness_features; |
| 4218 | device_extensions.push_back(x: "VK_EXT_pipeline_robustness" ); |
| 4219 | } |
| 4220 | |
| 4221 | VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; |
| 4222 | subgroup_size_control_features.pNext = nullptr; |
| 4223 | subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; |
| 4224 | subgroup_size_control_features.computeFullSubgroups = false; |
| 4225 | subgroup_size_control_features.subgroupSizeControl = false; |
| 4226 | |
| 4227 | if (device->subgroup_size_control) { |
| 4228 | last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; |
| 4229 | last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; |
| 4230 | } |
| 4231 | |
| 4232 | #if defined(VK_KHR_cooperative_matrix) |
| 4233 | VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; |
| 4234 | coopmat_features.pNext = nullptr; |
| 4235 | coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; |
| 4236 | coopmat_features.cooperativeMatrix = VK_FALSE; |
| 4237 | |
| 4238 | if (device->coopmat_support) { |
| 4239 | last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; |
| 4240 | last_struct = (VkBaseOutStructure *)&coopmat_features; |
| 4241 | } |
| 4242 | #endif |
| 4243 | |
| 4244 | #if defined(VK_NV_cooperative_matrix2) |
| 4245 | VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; |
| 4246 | coopmat2_features.pNext = nullptr; |
| 4247 | coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; |
| 4248 | if (coopmat2_support) { |
| 4249 | last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; |
| 4250 | last_struct = (VkBaseOutStructure *)&coopmat2_features; |
| 4251 | device_extensions.push_back(x: "VK_NV_cooperative_matrix2" ); |
| 4252 | } |
| 4253 | #endif |
| 4254 | |
| 4255 | #if defined(VK_KHR_shader_bfloat16) |
| 4256 | VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; |
| 4257 | bfloat16_features.pNext = nullptr; |
| 4258 | bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; |
| 4259 | if (bfloat16_support) { |
| 4260 | last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; |
| 4261 | last_struct = (VkBaseOutStructure *)&bfloat16_features; |
| 4262 | device_extensions.push_back(x: "VK_KHR_shader_bfloat16" ); |
| 4263 | } |
| 4264 | #endif |
| 4265 | |
| 4266 | VkPhysicalDeviceMaintenance4Features maint4_features {}; |
| 4267 | maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; |
| 4268 | if (maintenance4_support) { |
| 4269 | last_struct->pNext = (VkBaseOutStructure *)&maint4_features; |
| 4270 | last_struct = (VkBaseOutStructure *)&maint4_features; |
| 4271 | device_extensions.push_back(x: "VK_KHR_maintenance4" ); |
| 4272 | } |
| 4273 | |
| 4274 | VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; |
| 4275 | shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; |
| 4276 | if (device->integer_dot_product) { |
| 4277 | last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; |
| 4278 | last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; |
| 4279 | device_extensions.push_back(x: "VK_KHR_shader_integer_dot_product" ); |
| 4280 | } |
| 4281 | |
| 4282 | VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; |
| 4283 | pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; |
| 4284 | if (pipeline_executable_properties_support) { |
| 4285 | last_struct->pNext = (VkBaseOutStructure *)&pep_features; |
| 4286 | last_struct = (VkBaseOutStructure *)&pep_features; |
| 4287 | device_extensions.push_back(x: "VK_KHR_pipeline_executable_properties" ); |
| 4288 | } |
| 4289 | |
| 4290 | vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); |
| 4291 | |
| 4292 | device->pipeline_executable_properties_support = pipeline_executable_properties_support; |
| 4293 | |
| 4294 | device->fp16 = device->fp16 && vk12_features.shaderFloat16; |
| 4295 | |
| 4296 | #if defined(VK_KHR_shader_bfloat16) |
| 4297 | device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; |
| 4298 | #else |
| 4299 | device->bf16 = false; |
| 4300 | #endif |
| 4301 | |
| 4302 | device->pipeline_robustness = pl_robustness_features.pipelineRobustness; |
| 4303 | |
| 4304 | device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && |
| 4305 | device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) && |
| 4306 | getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr; |
| 4307 | |
| 4308 | device->shader_int64 = device_features2.features.shaderInt64; |
| 4309 | device->buffer_device_address = vk12_features.bufferDeviceAddress; |
| 4310 | |
| 4311 | if (device->subgroup_size_control) { |
| 4312 | device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; |
| 4313 | device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; |
| 4314 | device_extensions.push_back(x: "VK_EXT_subgroup_size_control" ); |
| 4315 | } |
| 4316 | |
| 4317 | device->subgroup_size_control = device->subgroup_size_control && |
| 4318 | (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && |
| 4319 | subgroup_size_control_features.subgroupSizeControl; |
| 4320 | |
| 4321 | device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; |
| 4322 | |
| 4323 | #if defined(VK_KHR_cooperative_matrix) |
| 4324 | device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; |
| 4325 | |
| 4326 | // coopmat1 fa shader currently assumes 32 invocations per subgroup |
| 4327 | device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && |
| 4328 | device->subgroup_size_control && device->subgroup_min_size <= 32 && |
| 4329 | device->subgroup_max_size >= 32; |
| 4330 | #endif |
| 4331 | |
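| | // NV_cooperative_matrix2 is only used if every required sub-feature is present and the driver exposes workgroup-scope flexible-dimension shapes for fp16 inputs with fp16 and fp32 accumulators at workgroup sizes 128 and 256. |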
| 4332 | if (coopmat2_support) { |
| 4333 | #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 4334 | if (coopmat2_features.cooperativeMatrixWorkgroupScope && |
| 4335 | coopmat2_features.cooperativeMatrixFlexibleDimensions && |
| 4336 | coopmat2_features.cooperativeMatrixReductions && |
| 4337 | coopmat2_features.cooperativeMatrixConversions && |
| 4338 | coopmat2_features.cooperativeMatrixPerElementOperations && |
| 4339 | coopmat2_features.cooperativeMatrixTensorAddressing && |
| 4340 | coopmat2_features.cooperativeMatrixBlockLoads && |
| 4341 | vk12_features.bufferDeviceAddress) { |
| 4342 | |
| 4343 | std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions; |
| 4344 | uint32_t count = 0; |
| 4345 | |
| 4346 | PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV |
| 4347 | _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = |
| 4348 | (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) |
| 4349 | vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); |
| 4350 | |
| 4351 | _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); |
| 4352 | |
| 4353 | VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; |
| 4354 | empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; |
| 4355 | flexible_dimensions.resize(count, empty_prop); |
| 4356 | |
| 4357 | _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); |
| 4358 | |
| 4359 | bool found_fp16_128 = false, |
| 4360 | found_fp16_256 = false, |
| 4361 | found_fp32_128 = false, |
| 4362 | found_fp32_256 = false; |
| 4363 | // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 |
| 4364 | // with 32x16x16 and 256 with 32x32x16. |
| 4365 | for (auto &prop : flexible_dimensions) { |
| 4366 | if (prop.saturatingAccumulation == VK_FALSE && |
| 4367 | prop.scope == VK_SCOPE_WORKGROUP_KHR && |
| 4368 | prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && |
| 4369 | prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { |
| 4370 | |
| 4371 | if (prop.workgroupInvocations == 128 && |
| 4372 | prop.MGranularity <= 32 && |
| 4373 | prop.NGranularity <= 16 && |
| 4374 | prop.KGranularity <= 16) { |
| 4375 | if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && |
| 4376 | prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { |
| 4377 | found_fp16_128 = true; |
| 4378 | } |
| 4379 | if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && |
| 4380 | prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { |
| 4381 | found_fp32_128 = true; |
| 4382 | } |
| 4383 | } |
| 4384 | if (prop.workgroupInvocations == 256 && |
| 4385 | prop.MGranularity <= 32 && |
| 4386 | prop.NGranularity <= 32 && |
| 4387 | prop.KGranularity <= 16) { |
| 4388 | if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && |
| 4389 | prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { |
| 4390 | found_fp16_256 = true; |
| 4391 | } |
| 4392 | if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && |
| 4393 | prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { |
| 4394 | found_fp32_256 = true; |
| 4395 | } |
| 4396 | } |
| 4397 | } |
| 4398 | } |
| 4399 | if (found_fp16_128 && found_fp16_256 && |
| 4400 | found_fp32_128 && found_fp32_256 && |
| 4401 | coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { |
| 4402 | device->coopmat2 = true; |
| 4403 | } |
| 4404 | } |
| 4405 | #endif |
| 4406 | } |
| 4407 | |
| 4408 | if (!vk11_features.storageBuffer16BitAccess) { |
| 4409 | std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl; |
| 4410 | throw std::runtime_error("Unsupported device" ); |
| 4411 | } |
| 4412 | |
| 4413 | device_extensions.push_back(x: "VK_KHR_16bit_storage" ); |
| 4414 | |
| 4415 | #ifdef GGML_VULKAN_VALIDATE |
| 4416 | device_extensions.push_back("VK_KHR_shader_non_semantic_info" ); |
| 4417 | #endif |
| 4418 | |
| 4419 | if (device->fp16) { |
| 4420 | device_extensions.push_back(x: "VK_KHR_shader_float16_int8" ); |
| 4421 | } |
| 4422 | |
| 4423 | #if defined(VK_KHR_cooperative_matrix) |
| 4424 | if (device->coopmat_support) { |
| 4425 | // Query supported shapes |
| 4426 | std::vector<VkCooperativeMatrixPropertiesKHR> cm_props; |
| 4427 | |
| 4428 | PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = |
| 4429 | (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); |
| 4430 | |
| 4431 | uint32_t cm_props_num; |
| 4432 | |
| 4433 | pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); |
| 4434 | |
| 4435 | cm_props.resize(cm_props_num); |
| 4436 | |
| 4437 | for (auto& prop : cm_props) { |
| 4438 | prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; |
| 4439 | } |
| 4440 | |
| 4441 | pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); |
| 4442 | |
| 4443 | VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); |
| 4444 | |
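| | // Pick the first advertised subgroup-scope shape and only accept additional accumulator types if they share the exact same M/N/K dimensions. |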
| 4445 | for (auto& prop : cm_props) { |
| 4446 | VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); |
| 4447 | |
| 4448 | if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && |
| 4449 | (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && |
| 4450 | (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup |
| 4451 | ) { |
| 4452 | if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && |
| 4453 | (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { |
| 4454 | // coopmat sizes not set yet |
| 4455 | if (device->coopmat_m == 0) { |
| 4456 | device->coopmat_acc_f32_support = true; |
| 4457 | device->coopmat_m = prop.MSize; |
| 4458 | device->coopmat_n = prop.NSize; |
| 4459 | device->coopmat_k = prop.KSize; |
| 4460 | } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { |
| 4461 | // Only enable if shape is identical |
| 4462 | device->coopmat_acc_f32_support = true; |
| 4463 | } |
| 4464 | if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { |
| 4465 | device->coopmat_support_16x16x16_f32acc = true; |
| 4466 | } |
| 4467 | } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && |
| 4468 | (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { |
| 4469 | // coopmat sizes not set yet |
| 4470 | if (device->coopmat_m == 0) { |
| 4471 | device->coopmat_acc_f16_support = true; |
| 4472 | device->coopmat_m = prop.MSize; |
| 4473 | device->coopmat_n = prop.NSize; |
| 4474 | device->coopmat_k = prop.KSize; |
| 4475 | } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { |
| 4476 | // Only enable if shape is identical |
| 4477 | device->coopmat_acc_f16_support = true; |
| 4478 | } |
| 4479 | if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { |
| 4480 | device->coopmat_support_16x16x16_f16acc = true; |
| 4481 | } |
| 4482 | } |
| 4483 | } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && |
| 4484 | (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && |
| 4485 | (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 && |
| 4486 | (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 && |
| 4487 | (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup && |
| 4488 | device->coopmat_int_m == 0 |
| 4489 | ) { |
| 4490 | device->coopmat_int_support = true; |
| 4491 | device->coopmat_int_m = prop.MSize; |
| 4492 | device->coopmat_int_n = prop.NSize; |
| 4493 | device->coopmat_int_k = prop.KSize; |
| 4494 | } |
| 4495 | #if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 4496 | if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && |
| 4497 | prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && |
| 4498 | prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && |
| 4499 | prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR && |
| 4500 | (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup |
| 4501 | ) { |
| 4502 | // coopmat sizes not set yet |
| 4503 | if (device->coopmat_m == 0) { |
| 4504 | device->coopmat_bf16_support = true; |
| 4505 | device->coopmat_m = prop.MSize; |
| 4506 | device->coopmat_n = prop.NSize; |
| 4507 | device->coopmat_k = prop.KSize; |
| 4508 | } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { |
| 4509 | // Only enable if shape is identical |
| 4510 | device->coopmat_bf16_support = true; |
| 4511 | } |
| 4512 | } |
| 4513 | #endif |
| 4514 | } |
| 4515 | |
| 4516 | if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { |
| 4517 | // No suitable matmul mode found |
| 4518 | GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n" ); |
| 4519 | device->coopmat_support = false; |
| 4520 | } |
| 4521 | if (getenv("GGML_VK_DISABLE_BFLOAT16")) { |
| 4522 | device->coopmat_bf16_support = false; |
| 4523 | } |
| 4524 | } |
| 4525 | |
| 4526 | if (device->coopmat_support) { |
| 4527 | device_extensions.push_back(x: "VK_KHR_cooperative_matrix" ); |
| 4528 | } |
| 4529 | #if defined(VK_KHR_shader_bfloat16) |
| 4530 | if (device->coopmat_bf16_support) { |
| 4531 | device_extensions.push_back(x: "VK_KHR_shader_bfloat16" ); |
| 4532 | } |
| 4533 | #endif |
| 4534 | #endif |
| 4535 | device->name = GGML_VK_NAME + std::to_string(idx); |
| 4536 | |
| 4537 | device_create_info = { |
| 4538 | vk::DeviceCreateFlags(), |
| 4539 | device_queue_create_infos, |
| 4540 | {}, |
| 4541 | device_extensions |
| 4542 | }; |
| 4543 | device_create_info.setPNext(&device_features2); |
| 4544 | device->device = device->physical_device.createDevice(device_create_info); |
| 4545 | |
| 4546 | // Queues |
| 4547 | ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); |
| 4548 | |
| 4549 | // Shaders |
| 4550 | // Disable some matmul tile sizes early if they perform poorly or are not supported on this vendor |
| 4551 | for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { |
| 4552 | switch (device->vendor_id) { |
| 4553 | #ifndef GGML_VULKAN_RUN_TESTS |
| 4554 | case VK_VENDOR_ID_AMD: |
| 4555 | case VK_VENDOR_ID_INTEL: |
| 4556 | device->mul_mat_l[i] = false; |
| 4557 | device->mul_mat_m[i] = true; |
| 4558 | device->mul_mat_s[i] = true; |
| 4559 | device->mul_mat_id_l[i] = false; |
| 4560 | device->mul_mat_id_m[i] = true; |
| 4561 | device->mul_mat_id_s[i] = true; |
| 4562 | break; |
| 4563 | case VK_VENDOR_ID_APPLE: |
| 4564 | device->mul_mat_l[i] = false; |
| 4565 | device->mul_mat_m[i] = true; |
| 4566 | device->mul_mat_s[i] = false; |
| 4567 | device->mul_mat_id_l[i] = false; |
| 4568 | device->mul_mat_id_m[i] = true; |
| 4569 | device->mul_mat_id_s[i] = false; |
| 4570 | break; |
| 4571 | #endif |
| 4572 | default: |
| 4573 | device->mul_mat_l[i] = true; |
| 4574 | device->mul_mat_m[i] = true; |
| 4575 | device->mul_mat_s[i] = true; |
| 4576 | device->mul_mat_id_l[i] = true; |
| 4577 | device->mul_mat_id_m[i] = true; |
| 4578 | device->mul_mat_id_s[i] = true; |
| 4579 | break; |
| 4580 | } |
| 4581 | } |
| 4582 | |
| 4583 | |
| 4584 | std::vector<vk::DescriptorSetLayoutBinding> dsl_binding; |
| 4585 | std::vector<vk::DescriptorBindingFlags> dsl_binding_flags; |
| 4586 | for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) { |
| 4587 | dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); |
| 4588 | dsl_binding_flags.push_back({}); |
| 4589 | } |
| 4590 | |
| 4591 | vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; |
| 4592 | |
| 4593 | vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( |
| 4594 | {}, |
| 4595 | dsl_binding); |
| 4596 | descriptor_set_layout_create_info.setPNext(&dslbfci); |
| 4597 | device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); |
| 4598 | |
| 4599 | ggml_vk_load_shaders(device); |
| 4600 | |
| 4601 | if (!device->single_queue) { |
| 4602 | const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; |
| 4603 | ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); |
| 4604 | } else { |
| 4605 | // TODO: Use pointer or reference to avoid copy |
| 4606 | device->transfer_queue.copyFrom(device->compute_queue); |
| 4607 | device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); |
| 4608 | } |
| 4609 | |
| 4610 | device->buffer_type = { |
| 4611 | /* .iface = */ ggml_backend_vk_buffer_type_interface, |
| 4612 | /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), |
| 4613 | /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, |
| 4614 | }; |
| 4615 | |
| 4616 | device->fence = device->device.createFence({}); |
| 4617 | |
| 4618 | device->idx = idx; |
| 4619 | |
| 4620 | device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; |
| 4621 | |
| 4622 | device->add_rms_fusion = !device->disable_fusion && |
| 4623 | device->subgroup_arithmetic && |
| 4624 | device->vendor_id != VK_VENDOR_ID_INTEL; |
| 4625 | device->partials_binding_alignment = |
| 4626 | std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); |
| 4627 | |
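| | // mmvq_mode: GGML_VK_DISABLE_MMVQ sets -1 and GGML_VK_FORCE_MMVQ sets 1; 0 keeps the default heuristic. |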
| 4628 | device->mmvq_mode = 0; |
| 4629 | if (getenv("GGML_VK_DISABLE_MMVQ")) { |
| 4630 | device->mmvq_mode = -1; |
| 4631 | } else if (getenv("GGML_VK_FORCE_MMVQ")) { |
| 4632 | device->mmvq_mode = 1; |
| 4633 | } |
| 4634 | |
| 4635 | return device; |
| 4636 | } |
| 4637 | |
| 4638 | return vk_instance.devices[idx]; |
| 4639 | } |
| 4640 | |
| 4641 | static void ggml_vk_print_gpu_info(size_t idx) { |
| 4642 | GGML_ASSERT(idx < vk_instance.device_indices.size()); |
| 4643 | size_t dev_num = vk_instance.device_indices[idx]; |
| 4644 | VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")" ); |
| 4645 | GGML_ASSERT(vk_instance_initialized); |
| 4646 | |
| 4647 | std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices(); |
| 4648 | |
| 4649 | if (dev_num >= devices.size()) { |
| 4650 | std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; |
| 4651 | throw std::runtime_error("Device not found" ); |
| 4652 | } |
| 4653 | |
| 4654 | vk::PhysicalDevice physical_device = devices[dev_num]; |
| 4655 | std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties(); |
| 4656 | |
| 4657 | bool fp16_storage = false; |
| 4658 | bool fp16_compute = false; |
| 4659 | bool coopmat_support = false; |
| 4660 | bool coopmat2_support = false; |
| 4661 | bool integer_dot_product = false; |
| 4662 | bool bfloat16_support = false; |
| 4663 | |
| 4664 | for (auto properties : ext_props) { |
| 4665 | if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { |
| 4666 | fp16_storage = true; |
| 4667 | } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { |
| 4668 | fp16_compute = true; |
| 4669 | #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) |
| 4670 | } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && |
| 4671 | !getenv("GGML_VK_DISABLE_COOPMAT")) { |
| 4672 | coopmat_support = true; |
| 4673 | #endif |
| 4674 | #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) |
| 4675 | } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && |
| 4676 | !getenv("GGML_VK_DISABLE_COOPMAT2")) { |
| 4677 | coopmat2_support = true; |
| 4678 | #endif |
| 4679 | #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) |
| 4680 | } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && |
| 4681 | !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { |
| 4682 | integer_dot_product = true; |
| 4683 | #endif |
| 4684 | #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) |
| 4685 | } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && |
| 4686 | !getenv("GGML_VK_DISABLE_BFLOAT16")) { |
| 4687 | bfloat16_support = true; |
| 4688 | #endif |
| 4689 | } |
| 4690 | } |
| 4691 | |
| 4692 | const vk_device_architecture device_architecture = get_device_architecture(physical_device); |
| 4693 | |
| 4694 | const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); |
| 4695 | bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; |
| 4696 | |
| 4697 | bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; |
| 4698 | |
| 4699 | vk::PhysicalDeviceProperties2 props2; |
| 4700 | vk::PhysicalDeviceMaintenance3Properties props3; |
| 4701 | vk::PhysicalDeviceSubgroupProperties subgroup_props; |
| 4702 | vk::PhysicalDeviceDriverProperties driver_props; |
| 4703 | vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; |
| 4704 | props2.pNext = &props3; |
| 4705 | props3.pNext = &subgroup_props; |
| 4706 | subgroup_props.pNext = &driver_props; |
| 4707 | |
| 4708 | // Pointer to the last chain element |
| 4709 | VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props; |
| 4710 | |
| 4711 | if (integer_dot_product) { |
| 4712 | last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; |
| 4713 | last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; |
| 4714 | } |
| 4715 | |
| 4716 | physical_device.getProperties2(&props2); |
| 4717 | |
| 4718 | VkPhysicalDeviceFeatures2 device_features2; |
| 4719 | device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; |
| 4720 | device_features2.pNext = nullptr; |
| 4721 | |
| 4722 | VkPhysicalDeviceVulkan11Features vk11_features; |
| 4723 | vk11_features.pNext = nullptr; |
| 4724 | vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; |
| 4725 | device_features2.pNext = &vk11_features; |
| 4726 | |
| 4727 | VkPhysicalDeviceVulkan12Features vk12_features; |
| 4728 | vk12_features.pNext = nullptr; |
| 4729 | vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; |
| 4730 | vk11_features.pNext = &vk12_features; |
| 4731 | |
| 4732 | // Pointer to the last chain element |
| 4733 | last_struct = (VkBaseOutStructure *)&vk12_features; |
| 4734 | |
| 4735 | #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) |
| 4736 | VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; |
| 4737 | coopmat_features.pNext = nullptr; |
| 4738 | coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; |
| 4739 | coopmat_features.cooperativeMatrix = VK_FALSE; |
| 4740 | |
| 4741 | if (coopmat_support) { |
| 4742 | last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; |
| 4743 | last_struct = (VkBaseOutStructure *)&coopmat_features; |
| 4744 | } |
| 4745 | #endif |
| 4746 | |
| 4747 | VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; |
| 4748 | shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; |
| 4749 | if (integer_dot_product) { |
| 4750 | last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; |
| 4751 | last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; |
| 4752 | } |
| 4753 | |
| 4754 | #if defined(VK_KHR_shader_bfloat16) |
| 4755 | VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; |
| 4756 | bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; |
| 4757 | if (bfloat16_support) { |
| 4758 | last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; |
| 4759 | last_struct = (VkBaseOutStructure *)&bfloat16_features; |
| 4760 | } |
| 4761 | #endif |
| 4762 | |
| 4763 | vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); |
| 4764 | |
| 4765 | fp16 = fp16 && vk12_features.shaderFloat16; |
| 4766 | |
| 4767 | #if defined(VK_KHR_shader_bfloat16) |
| 4768 | bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; |
| 4769 | #else |
| 4770 | bool bf16 = false; |
| 4771 | #endif |
| 4772 | |
| 4773 | uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); |
| 4774 | const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; |
| 4775 | const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; |
| 4776 | |
| 4777 | integer_dot_product = integer_dot_product |
| 4778 | && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated |
| 4779 | && shader_integer_dot_product_features.shaderIntegerDotProduct; |
| 4780 | |
| 4781 | coopmat_support = coopmat_support |
| 4782 | #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) |
| 4783 | && coopmat_features.cooperativeMatrix |
| 4784 | #endif |
| 4785 | && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); |
| 4786 | |
| 4787 | std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none" ; |
| 4788 | |
| 4789 | std::string device_name = props2.properties.deviceName.data(); |
| 4790 | GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n" , |
| 4791 | idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, |
| 4792 | props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); |
| 4793 | |
| 4794 | if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { |
| 4795 | GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n" ); |
| 4796 | } |
| 4797 | } |
| 4798 | |
| 4799 | static bool ggml_vk_instance_validation_ext_available(); |
| 4800 | static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions); |
| 4801 | static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions); |
| 4802 | static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); |
| 4803 | |
| 4804 | static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; |
| 4805 | DispatchLoaderDynamic & ggml_vk_default_dispatcher() { |
| 4806 | return ggml_vk_default_dispatcher_instance; |
| 4807 | } |
| 4808 | |
| 4809 | static void ggml_vk_instance_init() { |
| 4810 | if (vk_instance_initialized) { |
| 4811 | return; |
| 4812 | } |
| 4813 | VK_LOG_DEBUG("ggml_vk_instance_init()" ); |
| 4814 | |
| 4815 | // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- |
| 4816 | ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr); |
| 4817 | |
| 4818 | uint32_t api_version = vk::enumerateInstanceVersion(); |
| 4819 | |
| 4820 | if (api_version < VK_API_VERSION_1_2) { |
| 4821 | std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; |
| 4822 | throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required" ); |
| 4823 | } |
| 4824 | |
| 4825 | vk::ApplicationInfo app_info{ "ggml-vulkan" , 1, nullptr, 0, api_version }; |
| 4826 | |
| 4827 | const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties(); |
| 4828 | const bool validation_ext = ggml_vk_instance_validation_ext_available(); |
| 4829 | #ifdef __APPLE__ |
| 4830 | const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); |
| 4831 | #endif |
| 4832 | const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr; |
| 4833 | std::vector<const char*> layers; |
| 4834 | |
| 4835 | if (validation_ext) { |
| 4836 | layers.push_back("VK_LAYER_KHRONOS_validation"); |
| 4837 | } |
| 4838 | std::vector<const char*> extensions; |
| 4839 | if (validation_ext) { |
| 4840 | extensions.push_back("VK_EXT_validation_features"); |
| 4841 | } |
| 4842 | #ifdef __APPLE__ |
| 4843 | if (portability_enumeration_ext) { |
| 4844 | extensions.push_back("VK_KHR_portability_enumeration" ); |
| 4845 | } |
| 4846 | #endif |
| 4847 | if (debug_utils_ext) { |
| 4848 | extensions.push_back(x: "VK_EXT_debug_utils" ); |
| 4849 | } |
| 4850 | vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); |
| 4851 | #ifdef __APPLE__ |
| 4852 | if (portability_enumeration_ext) { |
| 4853 | instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; |
| 4854 | } |
| 4855 | #endif |
| 4856 | |
| 4857 | std::vector<vk::ValidationFeatureEnableEXT> features_enable; |
| 4858 | vk::ValidationFeaturesEXT validation_features; |
| 4859 | |
| 4860 | if (validation_ext) { |
| 4861 | features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; |
| 4862 | validation_features = { |
| 4863 | features_enable, |
| 4864 | {}, |
| 4865 | }; |
| 4866 | validation_features.setPNext(nullptr); |
| 4867 | instance_create_info.setPNext(&validation_features); |
| 4868 | GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n" ); |
| 4869 | } |
| 4870 | vk_instance.instance = vk::createInstance(instance_create_info); |
| 4871 | vk_instance_initialized = true; |
| 4872 | |
| 4873 | if (debug_utils_ext) { |
| 4874 | vk_instance.debug_utils_support = true; |
| 4875 | vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT"); |
| 4876 | vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT"); |
| 4877 | vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT"); |
| 4878 | vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT"); |
| 4879 | vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT"); |
| 4880 | vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT"); |
| 4881 | } |
| 4882 | |
| 4883 | vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; |
| 4884 | |
| 4885 | // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- |
| 4886 | VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance); |
| 4887 | |
| 4888 | std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices(); |
| 4889 | |
| 4890 | // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan |
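| | // Example: GGML_VK_VISIBLE_DEVICES=0,2 exposes only the first and third physical devices reported by the Vulkan instance. |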
| 4891 | char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); |
| 4892 | if (devices_env != nullptr) { |
| 4893 | size_t num_available_devices = devices.size(); |
| 4894 | |
| 4895 | std::string devices(devices_env); |
| 4896 | std::replace(devices.begin(), devices.end(), ',', ' '); |
| 4897 | |
| 4898 | std::stringstream ss(devices); |
| 4899 | size_t tmp; |
| 4900 | while (ss >> tmp) { |
| 4901 | if(tmp >= num_available_devices) { |
| 4902 | std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; |
| 4903 | throw std::runtime_error("Invalid Vulkan device index" ); |
| 4904 | } |
| 4905 | vk_instance.device_indices.push_back(tmp); |
| 4906 | } |
| 4907 | } else { |
| 4908 | // If no Vulkan devices are found, return early |
| 4909 | if (devices.empty()) { |
| 4910 | GGML_LOG_INFO("ggml_vulkan: No devices found.\n" ); |
| 4911 | return; |
| 4912 | } |
| 4913 | |
| 4914 | // Default to using all dedicated GPUs |
| 4915 | for (size_t i = 0; i < devices.size(); i++) { |
| 4916 | vk::PhysicalDeviceProperties2 new_props; |
| 4917 | vk::PhysicalDeviceDriverProperties new_driver; |
| 4918 | vk::PhysicalDeviceIDProperties new_id; |
| 4919 | new_props.pNext = &new_driver; |
| 4920 | new_driver.pNext = &new_id; |
| 4921 | devices[i].getProperties2(&new_props); |
| 4922 | |
| 4923 | if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { |
| 4924 | // Check if there are two physical devices corresponding to the same GPU |
| 4925 | auto old_device = std::find_if( |
| 4926 | vk_instance.device_indices.begin(), |
| 4927 | vk_instance.device_indices.end(), |
| 4928 | [&devices, &new_id](const size_t k){ |
| 4929 | vk::PhysicalDeviceProperties2 old_props; |
| 4930 | vk::PhysicalDeviceIDProperties old_id; |
| 4931 | old_props.pNext = &old_id; |
| 4932 | devices[k].getProperties2(&old_props); |
| 4933 | |
| 4934 | bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); |
| 4935 | equals = equals || ( |
| 4936 | old_id.deviceLUIDValid && new_id.deviceLUIDValid && |
| 4937 | std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID)) |
| 4938 | ); |
| 4939 | |
| 4940 | return equals; |
| 4941 | } |
| 4942 | ); |
| 4943 | if (old_device == vk_instance.device_indices.end()) { |
| 4944 | vk_instance.device_indices.push_back(i); |
| 4945 | } else { |
| 4946 | // There can be two physical devices corresponding to the same GPU if there are 2 different drivers |
| 4947 | // This can cause errors when splitting layers across the devices, so keep only one |
| 4948 | VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID" ); |
| 4949 | |
| 4950 | vk::PhysicalDeviceProperties2 old_props; |
| 4951 | vk::PhysicalDeviceDriverProperties old_driver; |
| 4952 | old_props.pNext = &old_driver; |
| 4953 | devices[*old_device].getProperties2(&old_props); |
| 4954 | |
| 4955 | std::map<vk::DriverId, int> driver_priorities {}; |
| 4956 | int old_priority = std::numeric_limits<int>::max(); |
| 4957 | int new_priority = std::numeric_limits<int>::max(); |
| 4958 | |
| 4959 | // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver ids |
| 4960 | // Smaller number -> higher priority |
| 4961 | switch (old_props.properties.vendorID) { |
| 4962 | case VK_VENDOR_ID_AMD: |
| 4963 | driver_priorities[vk::DriverId::eMesaRadv] = 1; |
| 4964 | driver_priorities[vk::DriverId::eAmdOpenSource] = 2; |
| 4965 | driver_priorities[vk::DriverId::eAmdProprietary] = 3; |
| 4966 | break; |
| 4967 | case VK_VENDOR_ID_INTEL: |
| 4968 | driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; |
| 4969 | driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; |
| 4970 | break; |
| 4971 | case VK_VENDOR_ID_NVIDIA: |
| 4972 | driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; |
| 4973 | #if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 |
| 4974 | driver_priorities[vk::DriverId::eMesaNvk] = 2; |
| 4975 | #endif |
| 4976 | break; |
| 4977 | } |
| 4978 | driver_priorities[vk::DriverId::eMesaDozen] = 100; |
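| | // Dozen (Mesa's Vulkan-on-D3D12 layer) gets the lowest priority so any native driver for the same GPU wins. |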
| 4979 | |
| 4980 | if (driver_priorities.count(old_driver.driverID)) { |
| 4981 | old_priority = driver_priorities[old_driver.driverID]; |
| 4982 | } |
| 4983 | if (driver_priorities.count(new_driver.driverID)) { |
| 4984 | new_priority = driver_priorities[new_driver.driverID]; |
| 4985 | } |
| 4986 | |
| 4987 | if (new_priority < old_priority) { |
| 4988 | auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); |
| 4989 | vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); |
| 4990 | vk_instance.device_indices.push_back(i); |
| 4991 | |
| 4992 | VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); |
| 4993 | } |
| 4994 | else { |
| 4995 | VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); |
| 4996 | } |
| 4997 | } |
| 4998 | } |
| 4999 | } |
| 5000 | |
| 5001 | // If no GPUs found, fall back to the first non-CPU device. |
| 5002 | // If only CPU devices are available, return without devices. |
| 5003 | if (vk_instance.device_indices.empty()) { |
| 5004 | for (size_t i = 0; i < devices.size(); i++) { |
| 5005 | if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) { |
| 5006 | vk_instance.device_indices.push_back(i); |
| 5007 | break; |
| 5008 | } |
| 5009 | } |
| 5010 | } |
| 5011 | |
| 5012 | if (vk_instance.device_indices.empty()) { |
| 5013 | GGML_LOG_INFO("ggml_vulkan: No devices found.\n" ); |
| 5014 | return; |
| 5015 | } |
| 5016 | } |
| 5017 | GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n" , vk_instance.device_indices.size()); |
| 5018 | |
| 5019 | for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { |
| 5020 | vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]]; |
| 5021 | std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties(); |
| 5022 | |
| 5023 | bool membudget_supported = false; |
| 5024 | for (const auto & ext : extensionprops) { |
| 5025 | if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) { |
| 5026 | membudget_supported = true; |
| 5027 | break; |
| 5028 | } |
| 5029 | } |
| 5030 | |
| 5031 | vk_instance.device_supports_membudget.push_back(membudget_supported); |
| 5032 | |
| 5033 | ggml_vk_print_gpu_info(i); |
| 5034 | } |
| 5035 | } |
| 5036 | |
| 5037 | static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { |
| 5038 | VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")" ); |
| 5039 | ggml_vk_instance_init(); |
| 5040 | GGML_ASSERT(idx < vk_instance.device_indices.size()); |
| 5041 | |
| 5042 | ctx->name = GGML_VK_NAME + std::to_string(idx); |
| 5043 | |
| 5044 | ctx->device = ggml_vk_get_device(idx); |
| 5045 | |
| 5046 | ctx->semaphore_idx = 0; |
| 5047 | ctx->event_idx = 0; |
| 5048 | |
| 5049 | ctx->prealloc_size_x = 0; |
| 5050 | ctx->prealloc_size_y = 0; |
| 5051 | ctx->prealloc_size_split_k = 0; |
| 5052 | ctx->prealloc_size_add_rms_partials = 0; |
| 5053 | |
| 5054 | ctx->fence = ctx->device->device.createFence({}); |
| 5055 | ctx->almost_ready_fence = ctx->device->device.createFence({}); |
| 5056 | |
| 5057 | ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); |
| 5058 | ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); |
| 5059 | |
| 5060 | #ifdef GGML_VULKAN_CHECK_RESULTS |
| 5061 | const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS" ); |
| 5062 | vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); |
| 5063 | const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR" ); |
| 5064 | vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor)); |
| 5065 | #endif |
| 5066 | } |
| 5067 | |
| 5068 | static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { |
| 5069 | VK_LOG_DEBUG("ggml_vk_get_to_fp16()" ); |
| 5070 | switch (type) { |
| 5071 | case GGML_TYPE_F32: |
| 5072 | case GGML_TYPE_Q4_0: |
| 5073 | case GGML_TYPE_Q4_1: |
| 5074 | case GGML_TYPE_Q5_0: |
| 5075 | case GGML_TYPE_Q5_1: |
| 5076 | case GGML_TYPE_Q8_0: |
| 5077 | case GGML_TYPE_Q2_K: |
| 5078 | case GGML_TYPE_Q3_K: |
| 5079 | case GGML_TYPE_Q4_K: |
| 5080 | case GGML_TYPE_Q5_K: |
| 5081 | case GGML_TYPE_Q6_K: |
| 5082 | case GGML_TYPE_IQ1_S: |
| 5083 | case GGML_TYPE_IQ1_M: |
| 5084 | case GGML_TYPE_IQ2_XXS: |
| 5085 | case GGML_TYPE_IQ2_XS: |
| 5086 | case GGML_TYPE_IQ2_S: |
| 5087 | case GGML_TYPE_IQ3_XXS: |
| 5088 | case GGML_TYPE_IQ3_S: |
| 5089 | case GGML_TYPE_IQ4_XS: |
| 5090 | case GGML_TYPE_IQ4_NL: |
| 5091 | case GGML_TYPE_MXFP4: |
| 5092 | break; |
| 5093 | default: |
| 5094 | return nullptr; |
| 5095 | } |
| 5096 | |
| 5097 | return ctx->device->pipeline_dequant[type]; |
| 5098 | } |
| 5099 | |
| 5100 | static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { |
| 5101 | VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")" ); |
| 5102 | if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { |
| 5103 | return ctx->device->pipeline_matmul_f32; |
| 5104 | } |
| 5105 | if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { |
| 5106 | return ctx->device->pipeline_matmul_f32_f16; |
| 5107 | } |
| 5108 | if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { |
| 5109 | return ctx->device->pipeline_matmul_bf16; |
| 5110 | } |
| 5111 | if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { |
| 5112 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { |
| 5113 | return ctx->device->pipeline_matmul_f16_f32.f16acc; |
| 5114 | } |
| 5115 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { |
| 5116 | return ctx->device->pipeline_matmul_f16.f16acc; |
| 5117 | } |
| 5118 | } else { |
| 5119 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { |
| 5120 | return ctx->device->pipeline_matmul_f16_f32.f32acc; |
| 5121 | } |
| 5122 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { |
| 5123 | return ctx->device->pipeline_matmul_f16.f32acc; |
| 5124 | } |
| 5125 | } |
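|      | // The branches above pick the accumulator precision for F16 matmuls: fp16 accumulation is used
|      | // only with default precision on devices that support fp16 math (and, on coopmat devices, fp16
|      | // coopmat accumulation); everything else falls back to fp32 accumulators. Quantized src0 types
|      | // fall through to the dequant pipelines selected below.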
| 5126 | |
| 5127 | // MMQ |
| 5128 | if (src1_type == GGML_TYPE_Q8_1) { |
| 5129 | vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; |
| 5130 | |
| 5131 | if (pipelines->is_empty()) { |
| 5132 | return nullptr; |
| 5133 | } |
| 5134 | |
| 5135 | return pipelines; |
| 5136 | } |
| 5137 | |
| 5138 | if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { |
| 5139 | return nullptr; |
| 5140 | } |
| 5141 | |
| 5142 | switch (src0_type) { |
| 5143 | case GGML_TYPE_Q4_0: |
| 5144 | case GGML_TYPE_Q4_1: |
| 5145 | case GGML_TYPE_Q5_0: |
| 5146 | case GGML_TYPE_Q5_1: |
| 5147 | case GGML_TYPE_Q8_0: |
| 5148 | case GGML_TYPE_Q2_K: |
| 5149 | case GGML_TYPE_Q3_K: |
| 5150 | case GGML_TYPE_Q4_K: |
| 5151 | case GGML_TYPE_Q5_K: |
| 5152 | case GGML_TYPE_Q6_K: |
| 5153 | case GGML_TYPE_IQ1_S: |
| 5154 | case GGML_TYPE_IQ1_M: |
| 5155 | case GGML_TYPE_IQ2_XXS: |
| 5156 | case GGML_TYPE_IQ2_XS: |
| 5157 | case GGML_TYPE_IQ2_S: |
| 5158 | case GGML_TYPE_IQ3_XXS: |
| 5159 | case GGML_TYPE_IQ3_S: |
| 5160 | case GGML_TYPE_IQ4_XS: |
| 5161 | case GGML_TYPE_IQ4_NL: |
| 5162 | case GGML_TYPE_MXFP4: |
| 5163 | break; |
| 5164 | default: |
| 5165 | return nullptr; |
| 5166 | } |
| 5167 | |
| 5168 | if (ctx->device->coopmat2) { |
| 5169 | assert(src1_type == GGML_TYPE_F16); |
| 5170 | return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc; |
| 5171 | } |
| 5172 | if (ctx->device->coopmat_support) { |
| 5173 | return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; |
| 5174 | } |
| 5175 | return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; |
| 5176 | } |
| 5177 | |
| 5178 | static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) { |
| 5179 | VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()" ); |
| 5180 | GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1); |
| 5181 | GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); |
| 5182 | |
| 5183 | if (b_type == GGML_TYPE_Q8_1) { |
| 5184 | switch (a_type) { |
| 5185 | case GGML_TYPE_Q4_0: |
| 5186 | case GGML_TYPE_Q4_1: |
| 5187 | case GGML_TYPE_Q5_0: |
| 5188 | case GGML_TYPE_Q5_1: |
| 5189 | case GGML_TYPE_Q8_0: |
| 5190 | break; |
| 5191 | default: |
| 5192 | return nullptr; |
| 5193 | } |
| 5194 | } |
| 5195 | |
| 5196 | switch (a_type) { |
| 5197 | case GGML_TYPE_F32: |
| 5198 | case GGML_TYPE_F16: |
| 5199 | case GGML_TYPE_BF16: |
| 5200 | case GGML_TYPE_Q4_0: |
| 5201 | case GGML_TYPE_Q4_1: |
| 5202 | case GGML_TYPE_Q5_0: |
| 5203 | case GGML_TYPE_Q5_1: |
| 5204 | case GGML_TYPE_Q8_0: |
| 5205 | case GGML_TYPE_Q2_K: |
| 5206 | case GGML_TYPE_Q3_K: |
| 5207 | case GGML_TYPE_Q4_K: |
| 5208 | case GGML_TYPE_Q5_K: |
| 5209 | case GGML_TYPE_Q6_K: |
| 5210 | case GGML_TYPE_IQ1_S: |
| 5211 | case GGML_TYPE_IQ1_M: |
| 5212 | case GGML_TYPE_IQ2_XXS: |
| 5213 | case GGML_TYPE_IQ2_XS: |
| 5214 | case GGML_TYPE_IQ2_S: |
| 5215 | case GGML_TYPE_IQ3_XXS: |
| 5216 | case GGML_TYPE_IQ3_S: |
| 5217 | case GGML_TYPE_IQ4_XS: |
| 5218 | case GGML_TYPE_IQ4_NL: |
| 5219 | case GGML_TYPE_MXFP4: |
| 5220 | break; |
| 5221 | default: |
| 5222 | return nullptr; |
| 5223 | } |
| 5224 | |
| 5225 | // heuristic to choose workgroup size |
| 5226 | uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; |
| 5227 | if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { |
| 5228 | // Prefer larger workgroups when M is small, to spread the work out more |
| 5229 | // and keep more SMs busy. |
| 5230 | // q6_k seems to prefer small workgroup size even for "medium" values of M. |
| 5231 | if (a_type == GGML_TYPE_Q6_K) { |
| 5232 | if (m < 4096 && k >= 1024) { |
| 5233 | dmmv_wg = DMMV_WG_SIZE_LARGE; |
| 5234 | } |
| 5235 | } else { |
| 5236 | if (m <= 8192 && k >= 1024) { |
| 5237 | dmmv_wg = DMMV_WG_SIZE_LARGE; |
| 5238 | } |
| 5239 | } |
| 5240 | } |
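|      | // Roughly speaking, DMMV_WG_SIZE_LARGE trades per-row parallelism for occupancy: with few rows
|      | // (small m) and enough work per row (large k) the larger workgroups keep more SMs busy, which is
|      | // why it is only preferred above for NVIDIA (Turing and newer) and Intel in that regime.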
| 5241 | |
| 5242 | if (b_type == GGML_TYPE_Q8_1) { |
| 5243 | if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { |
| 5244 | dmmv_wg = DMMV_WG_SIZE_SUBGROUP; |
| 5245 | } |
| 5246 | return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1]; |
| 5247 | } |
| 5248 | |
| 5249 | return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1]; |
| 5250 | } |
| 5251 | |
| 5252 | static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { |
| 5253 | VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()" ); |
| 5254 | if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { |
| 5255 | return ctx->device->pipeline_matmul_id_f32; |
| 5256 | } |
| 5257 | if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { |
| 5258 | return ctx->device->pipeline_matmul_id_bf16; |
| 5259 | } |
| 5260 | if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { |
| 5261 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { |
| 5262 | return ctx->device->pipeline_matmul_id_f16_f32.f16acc; |
| 5263 | } |
| 5264 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { |
| 5265 | return ctx->device->pipeline_matmul_id_f16.f16acc; |
| 5266 | } |
| 5267 | } else { |
| 5268 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { |
| 5269 | return ctx->device->pipeline_matmul_id_f16_f32.f32acc; |
| 5270 | } |
| 5271 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { |
| 5272 | return ctx->device->pipeline_matmul_id_f16.f32acc; |
| 5273 | } |
| 5274 | } |
| 5275 | |
| 5276 | // MMQ |
| 5277 | if (src1_type == GGML_TYPE_Q8_1) { |
| 5278 | vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc; |
| 5279 | |
| 5280 | if (pipelines->is_empty()) { |
| 5281 | return nullptr; |
| 5282 | } |
| 5283 | |
| 5284 | return pipelines; |
| 5285 | } |
| 5286 | |
| 5287 | GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); |
| 5288 | |
| 5289 | switch (src0_type) { |
| 5290 | case GGML_TYPE_Q4_0: |
| 5291 | case GGML_TYPE_Q4_1: |
| 5292 | case GGML_TYPE_Q5_0: |
| 5293 | case GGML_TYPE_Q5_1: |
| 5294 | case GGML_TYPE_Q8_0: |
| 5295 | case GGML_TYPE_Q2_K: |
| 5296 | case GGML_TYPE_Q3_K: |
| 5297 | case GGML_TYPE_Q4_K: |
| 5298 | case GGML_TYPE_Q5_K: |
| 5299 | case GGML_TYPE_Q6_K: |
| 5300 | case GGML_TYPE_IQ1_S: |
| 5301 | case GGML_TYPE_IQ1_M: |
| 5302 | case GGML_TYPE_IQ2_XXS: |
| 5303 | case GGML_TYPE_IQ2_XS: |
| 5304 | case GGML_TYPE_IQ2_S: |
| 5305 | case GGML_TYPE_IQ3_XXS: |
| 5306 | case GGML_TYPE_IQ3_S: |
| 5307 | case GGML_TYPE_IQ4_XS: |
| 5308 | case GGML_TYPE_IQ4_NL: |
| 5309 | case GGML_TYPE_MXFP4: |
| 5310 | break; |
| 5311 | default: |
| 5312 | return nullptr; |
| 5313 | } |
| 5314 | |
| 5315 | vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type]; |
| 5316 | // XXX TODO 'prec' is not actually allowed in mul_mat_id. |
| 5317 | bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/; |
| 5318 | bool support_fp16acc = !mmp.f16acc->is_empty(); |
| 5319 | bool support_fp32acc = !mmp.f32acc->is_empty(); |
| 5320 | |
| 5321 | if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) { |
| 5322 | return mmp.f16acc; |
| 5323 | } else { |
| 5324 | GGML_ASSERT(support_fp32acc); |
| 5325 | return mmp.f32acc; |
| 5326 | } |
| 5327 | } |
| 5328 | |
| 5329 | static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { |
| 5330 | VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()" ); |
| 5331 | GGML_ASSERT(b_type == GGML_TYPE_F32); |
| 5332 | |
| 5333 | switch (a_type) { |
| 5334 | case GGML_TYPE_F32: |
| 5335 | case GGML_TYPE_F16: |
| 5336 | case GGML_TYPE_BF16: |
| 5337 | case GGML_TYPE_Q4_0: |
| 5338 | case GGML_TYPE_Q4_1: |
| 5339 | case GGML_TYPE_Q5_0: |
| 5340 | case GGML_TYPE_Q5_1: |
| 5341 | case GGML_TYPE_Q8_0: |
| 5342 | case GGML_TYPE_Q2_K: |
| 5343 | case GGML_TYPE_Q3_K: |
| 5344 | case GGML_TYPE_Q4_K: |
| 5345 | case GGML_TYPE_Q5_K: |
| 5346 | case GGML_TYPE_Q6_K: |
| 5347 | case GGML_TYPE_IQ1_S: |
| 5348 | case GGML_TYPE_IQ1_M: |
| 5349 | case GGML_TYPE_IQ2_XXS: |
| 5350 | case GGML_TYPE_IQ2_XS: |
| 5351 | case GGML_TYPE_IQ2_S: |
| 5352 | case GGML_TYPE_IQ3_XXS: |
| 5353 | case GGML_TYPE_IQ3_S: |
| 5354 | case GGML_TYPE_IQ4_XS: |
| 5355 | case GGML_TYPE_IQ4_NL: |
| 5356 | case GGML_TYPE_MXFP4: |
| 5357 | break; |
| 5358 | default: |
| 5359 | return nullptr; |
| 5360 | } |
| 5361 | |
| 5362 | return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; |
| 5363 | } |
| 5364 | |
| 5365 | static void * ggml_vk_host_malloc(vk_device& device, size_t size) { |
| 5366 | VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")" ); |
| 5367 | vk_buffer buf = ggml_vk_create_buffer(device, size, |
| 5368 | {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
| 5369 | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); |
| 5370 | |
| 5371 | if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { |
| 5372 | fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
| 5373 | size/1024.0/1024.0); |
| 5374 | device->device.freeMemory(buf->device_memory);
| 5375 | device->device.destroyBuffer(buf->buffer);
| 5376 | return nullptr; |
| 5377 | } |
| 5378 | |
| 5379 | std::lock_guard<std::recursive_mutex> guard(device->mutex); |
| 5380 | device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
| 5381 | |
| 5382 | return buf->ptr; |
| 5383 | } |
| 5384 | |
| 5385 | static void ggml_vk_host_free(vk_device& device, void* ptr) { |
| 5386 | if (ptr == nullptr) { |
| 5387 | return; |
| 5388 | } |
| 5389 | VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")" ); |
| 5390 | std::lock_guard<std::recursive_mutex> guard(device->mutex); |
| 5391 | |
| 5392 | vk_buffer buf; |
| 5393 | size_t index; |
| 5394 | for (size_t i = 0; i < device->pinned_memory.size(); i++) { |
| 5395 | const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
| 5396 | const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
| 5397 | if (ptr >= addr && ptr < endr) {
| 5398 | buf = std::get<2>(device->pinned_memory[i]);
| 5399 | index = i; |
| 5400 | break; |
| 5401 | } |
| 5402 | } |
| 5403 | if (buf == nullptr) { |
| 5404 | fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
| 5405 | return; |
| 5406 | } |
| 5407 | |
| 5408 | ggml_vk_destroy_buffer(buf); |
| 5409 | |
| 5410 | device->pinned_memory.erase(device->pinned_memory.begin() + index);
| 5411 | } |
| 5412 | |
| 5413 | static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { |
| 5414 | std::lock_guard<std::recursive_mutex> guard(device->mutex); |
| 5415 | buf = nullptr; |
| 5416 | buf_offset = 0; |
| 5417 | for (size_t i = 0; i < device->pinned_memory.size(); i++) { |
| 5418 | const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
| 5419 | const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
| 5420 | if (ptr >= addr && ptr < endr) {
| 5421 | buf = std::get<2>(device->pinned_memory[i]);
| 5422 | buf_offset = ((const uint8_t *)ptr) - addr; |
| 5423 | break; |
| 5424 | } |
| 5425 | } |
| 5426 | } |
| 5427 | |
| 5428 | static vk_subbuffer ggml_vk_tensor_subbuffer( |
| 5429 | const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) { |
| 5430 | |
| 5431 | vk_buffer buffer = nullptr; |
| 5432 | size_t offset = 0; |
| 5433 | if (ctx->device->uma) { |
| 5434 | ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
| 5435 | } |
| 5436 | if (!buffer) { |
| 5437 | auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; |
| 5438 | buffer = buf_ctx->dev_buffer; |
| 5439 | offset = vk_tensor_offset(tensor) + tensor->view_offs; |
| 5440 | } |
| 5441 | GGML_ASSERT(buffer != nullptr); |
| 5442 | |
| 5443 | size_t size = ggml_nbytes(tensor); |
| 5444 | |
| 5445 | size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); |
| 5446 | // The shader must support misaligned offsets when indexing into the buffer |
| 5447 | GGML_ASSERT(allow_misalign || misalign_bytes == 0); |
| 5448 | offset &= ~misalign_bytes; |
| 5449 | size += misalign_bytes; |
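|      | // Example (assuming minStorageBufferOffsetAlignment == 64): offset == 100 gives
|      | // misalign_bytes == 36, the descriptor offset is rounded down to 64 and the range grows by
|      | // 36 bytes; the shader is then expected to add the misalignment back in when indexing.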
| 5450 | |
| 5451 | return vk_subbuffer{buffer, offset, size};
| 5452 | } |
| 5453 | |
| 5454 | static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { |
| 5455 | vk_submission s; |
| 5456 | s.buffer = ggml_vk_create_cmd_buffer(device, p); |
| 5457 | if (one_time) { |
| 5458 | s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
| 5459 | } else {
| 5460 | s.buffer.begin({ vk::CommandBufferUsageFlags{} });
| 5461 | } |
| 5462 | |
| 5463 | return s; |
| 5464 | } |
| 5465 | |
| 5466 | template <typename T> size_t push_constant_size(const T &t) { |
| 5467 | static_assert(std::is_class<T>::value, "T must be a struct/class" ); |
| 5468 | GGML_UNUSED(t); |
| 5469 | return sizeof(T); |
| 5470 | } |
| 5471 | template <typename T> size_t push_constant_size(const std::vector<T> &t) { |
| 5472 | GGML_UNUSED(t); |
| 5473 | return sizeof(T) * t.size(); |
| 5474 | } |
| 5475 | template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) { |
| 5476 | GGML_UNUSED(t); |
| 5477 | return sizeof(T) * N; |
| 5478 | } |
| 5479 | |
| 5480 | template <typename T> const T *push_constant_data(const T &t) { |
| 5481 | static_assert(std::is_class<T>::value, "T must be a struct/class" ); |
| 5482 | return &t; |
| 5483 | } |
| 5484 | template <typename T> const T *push_constant_data(const std::vector<T> &t) { |
| 5485 | return t.data(); |
| 5486 | } |
| 5487 | template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) { |
| 5488 | return t.data(); |
| 5489 | } |
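|      | // These overloads let ggml_vk_dispatch_pipeline accept either a single push-constant struct or a
|      | // std::vector/std::array of values, returning the matching byte size and data pointer for the
|      | // pushConstants call below.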
| 5490 | |
| 5491 | template <typename T> |
| 5492 | static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) { |
| 5493 | const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); |
| 5494 | const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); |
| 5495 | const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); |
| 5496 | VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {" ; |
| 5497 | for (auto& buffer : descriptor_buffer_infos) { |
| 5498 | std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), " ; |
| 5499 | } |
| 5500 | std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))" ); |
| 5501 | GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); |
| 5502 | GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); |
| 5503 | GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); |
| 5504 | |
| 5505 | vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; |
| 5506 | vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; |
| 5507 | ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
| 5508 | |
| 5509 | subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); |
| 5510 | subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
| 5511 | subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
| 5512 | pipeline->layout,
| 5513 | 0,
| 5514 | { descriptor_set },
| 5515 | {});
| 5516 | subctx->s->buffer.dispatch(wg0, wg1, wg2);
| 5517 | } |
| 5518 | |
| 5519 | static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) { |
| 5520 | s.buffer.end(); |
| 5521 | |
| 5522 | s.wait_semaphores = std::move(wait_semaphores); |
| 5523 | s.signal_semaphores = std::move(signal_semaphores); |
| 5524 | } |
| 5525 | |
| 5526 | static void ggml_vk_ctx_end(vk_context& ctx) { |
| 5527 | VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")" ); |
| 5528 | if (ctx->s == nullptr) { |
| 5529 | return; |
| 5530 | } |
| 5531 | |
| 5532 | ctx->s->buffer.end(); |
| 5533 | ctx->s = nullptr; |
| 5534 | } |
| 5535 | |
| 5536 | static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { |
| 5537 | VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")" ); |
| 5538 | if (subctx->s != nullptr) { |
| 5539 | ggml_vk_ctx_end(subctx);
| 5540 | } |
| 5541 | |
| 5542 | subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });
| 5543 | subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); |
| 5544 | } |
| 5545 | |
| 5546 | static size_t ggml_vk_align_size(size_t width, size_t align) { |
| 5547 | VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")" ); |
| 5548 | return CEIL_DIV(width, align) * align; |
| 5549 | } |
| 5550 | |
| 5551 | static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) { |
| 5552 | if (memcpys == nullptr) { |
| 5553 | memcpy(dst, src, size);
| 5554 | } else {
| 5555 | memcpys->emplace_back(dst, src, size);
| 5556 | } |
| 5557 | } |
| 5558 | |
| 5559 | static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) { |
| 5560 | if (memsets == nullptr) { |
| 5561 | memset(dst, val, size);
| 5562 | } else {
| 5563 | memsets->emplace_back(dst, val, size);
| 5564 | } |
| 5565 | } |
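|      | // When a memcpys/memsets vector is supplied, the operation is only recorded and is replayed later
|      | // by the caller (for example, ggml_vk_buffer_write_2d replays subctx->in_memcpys and
|      | // subctx->memsets before submitting the command buffer that consumes the staging data).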
| 5566 | |
| 5567 | static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { |
| 5568 | if (device->sync_staging == nullptr || device->sync_staging->size < size) { |
| 5569 | VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")" ); |
| 5570 | ggml_vk_destroy_buffer(device->sync_staging);
| 5571 | device->sync_staging = ggml_vk_create_buffer_check(device, size,
| 5572 | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
| 5573 | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
| 5574 | } |
| 5575 | } |
| 5576 | |
| 5577 | static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { |
| 5578 | VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")" ); |
| 5579 | GGML_ASSERT(!ggml_is_contiguous(tensor)); |
| 5580 | // Buffer is already mapped |
| 5581 | if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { |
| 5582 | std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl; |
| 5583 | GGML_ABORT("fatal error" ); |
| 5584 | } |
| 5585 | // Check if src is pinned memory |
| 5586 | vk_buffer buf = nullptr; |
| 5587 | size_t buf_offset = 0; |
| 5588 | ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
| 5589 | |
| 5590 | const uint64_t ne0 = tensor->ne[0]; |
| 5591 | const uint64_t ne1 = tensor->ne[1]; |
| 5592 | const uint64_t ne2 = tensor->ne[2]; |
| 5593 | const uint64_t ne3 = tensor->ne[3]; |
| 5594 | const uint64_t nb0 = tensor->nb[0]; |
| 5595 | const uint64_t nb1 = tensor->nb[1]; |
| 5596 | const uint64_t nb2 = tensor->nb[2]; |
| 5597 | const uint64_t nb3 = tensor->nb[3]; |
| 5598 | const ggml_type type = tensor->type; |
| 5599 | const uint64_t ts = ggml_type_size(type); |
| 5600 | const uint64_t bs = ggml_blck_size(type); |
| 5601 | |
| 5602 | const uint64_t dstnb0 = ts; |
| 5603 | const uint64_t dstnb1 = dstnb0*(ne0/bs); |
| 5604 | const uint64_t dstnb2 = dstnb1*ne1; |
| 5605 | const uint64_t dstnb3 = dstnb2*ne2; |
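|      | // dstnb0..dstnb3 are the byte strides of the destination, which is written as a densely packed
|      | // copy of the (non-contiguous) source tensor.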
| 5606 | |
| 5607 | const uint64_t ne = ggml_nelements(tensor); |
| 5608 | |
| 5609 | if (buf != nullptr) { |
| 5610 | // Memory is pinned, use as staging buffer |
| 5611 | std::vector<vk::BufferCopy> slices; |
| 5612 | |
| 5613 | for (uint64_t i3 = 0; i3 < ne3; i3++) { |
| 5614 | for (uint64_t i2 = 0; i2 < ne2; i2++) { |
| 5615 | // Find longest contiguous slice |
| 5616 | if (ne1*nb1 == dstnb2) { |
| 5617 | slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
| 5618 | } else { |
| 5619 | for (uint64_t i1 = 0; i1 < ne1; i1++) { |
| 5620 | if (ne0*nb0/bs == dstnb1) { |
| 5621 | slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
| 5622 | } else { |
| 5623 | const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; |
| 5624 | const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; |
| 5625 | for (uint64_t i0 = 0; i0 < ne0; i0++) { |
| 5626 | slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
| 5627 | } |
| 5628 | } |
| 5629 | } |
| 5630 | } |
| 5631 | } |
| 5632 | } |
| 5633 | |
| 5634 | ggml_vk_sync_buffers(ctx, subctx); |
| 5635 | subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
| 5636 | return; |
| 5637 | } |
| 5638 | |
| 5639 | if (!sync_staging) { |
| 5640 | GGML_ABORT("Asynchronous write to non-pinned memory not supported" ); |
| 5641 | } |
| 5642 | |
| 5643 | // Staging buffer required |
| 5644 | vk_buffer& staging = ctx->device->sync_staging; |
| 5645 | const uint64_t copy_size = ts*ne/bs; |
| 5646 | ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
| 5647 | VkBufferCopy buf_copy{ 0, offset, copy_size };
| 5648 | 
| 5649 | ggml_vk_sync_buffers(ctx, subctx);
| 5650 | vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
| 5651 | |
| 5652 | for (uint64_t i3 = 0; i3 < ne3; i3++) { |
| 5653 | for (uint64_t i2 = 0; i2 < ne2; i2++) { |
| 5654 | // Find longest contiguous slice |
| 5655 | if (ne1*nb1 == dstnb2) { |
| 5656 | deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
| 5657 | } else { |
| 5658 | for (uint64_t i1 = 0; i1 < ne1; i1++) { |
| 5659 | if (ne0*nb0/bs == dstnb1) { |
| 5660 | deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
| 5661 | } else { |
| 5662 | const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; |
| 5663 | const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; |
| 5664 | for (uint64_t i0 = 0; i0 < ne0; i0++) { |
| 5665 | deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
| 5666 | } |
| 5667 | } |
| 5668 | } |
| 5669 | } |
| 5670 | } |
| 5671 | } |
| 5672 | } |
| 5673 | |
| 5674 | static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { |
| 5675 | VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")" ); |
| 5676 | // Buffer is already mapped |
| 5677 | if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { |
| 5678 | std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; |
| 5679 | GGML_ABORT("fatal error" ); |
| 5680 | } |
| 5681 | // Check if src is pinned memory |
| 5682 | vk_buffer buf = nullptr; |
| 5683 | size_t buf_offset = 0; |
| 5684 | ggml_vk_host_get(dst->device, src, buf, buf_offset);
| 5685 | |
| 5686 | if (buf != nullptr) { |
| 5687 | // Memory is pinned, use as staging buffer |
| 5688 | std::vector<vk::BufferCopy> slices(1); |
| 5689 | if (width == spitch) { |
| 5690 | // Only do single write if stride is equal |
| 5691 | slices[0].srcOffset = buf_offset; |
| 5692 | slices[0].dstOffset = offset; |
| 5693 | slices[0].size = width * height; |
| 5694 | } else { |
| 5695 | slices.resize(height);
| 5696 | for (size_t i = 0; i < height; i++) { |
| 5697 | slices[i].srcOffset = buf_offset + i * spitch; |
| 5698 | slices[i].dstOffset = offset + i * width; |
| 5699 | slices[i].size = width; |
| 5700 | } |
| 5701 | } |
| 5702 | |
| 5703 | ggml_vk_sync_buffers(nullptr, subctx);
| 5704 | subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
| 5705 | return; |
| 5706 | } |
| 5707 | VK_LOG_DEBUG("STAGING" ); |
| 5708 | |
| 5709 | if (!sync_staging) { |
| 5710 | GGML_ABORT("Asynchronous write to non-pinned memory not supported" ); |
| 5711 | } |
| 5712 | |
| 5713 | // Staging buffer required |
| 5714 | const size_t copy_size = width*height; |
| 5715 | ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
| 5716 | 
| 5717 | vk_buffer& staging_buffer = dst->device->sync_staging;
| 5718 | 
| 5719 | VkBufferCopy buf_copy = {
| 5720 | 0,
| 5721 | offset,
| 5722 | copy_size};
| 5723 | 
| 5724 | ggml_vk_sync_buffers(nullptr, subctx);
| 5725 | vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
| 5726 | |
| 5727 | if (width == spitch) { |
| 5728 | deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
| 5729 | } else { |
| 5730 | for (size_t i = 0; i < height; i++) { |
| 5731 | deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
| 5732 | } |
| 5733 | } |
| 5734 | } |
| 5735 | |
| 5736 | static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { |
| 5737 | VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")" ); |
| 5738 | return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
| 5739 | } |
| 5740 | |
| 5741 | static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { |
| 5742 | VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")" ); |
| 5743 | // Buffer is already mapped |
| 5744 | if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { |
| 5745 | GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); |
| 5746 | |
| 5747 | for (size_t i = 0; i < height; i++) { |
| 5748 | memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
| 5749 | } |
| 5750 | } else { |
| 5751 | std::lock_guard<std::recursive_mutex> guard(dst->device->mutex); |
| 5752 | |
| 5753 | vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
| 5754 | ggml_vk_ctx_begin(dst->device, subctx);
| 5755 | ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
| 5756 | ggml_vk_ctx_end(subctx);
| 5757 | |
| 5758 | for (auto& cpy : subctx->in_memcpys) { |
| 5759 | memcpy(cpy.dst, cpy.src, cpy.n);
| 5760 | } |
| 5761 | |
| 5762 | for (auto& mset : subctx->memsets) { |
| 5763 | memset(mset.dst, mset.val, mset.n);
| 5764 | } |
| 5765 | |
| 5766 | ggml_vk_submit(subctx, dst->device->fence);
| 5767 | VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
| 5768 | dst->device->device.resetFences({ dst->device->fence });
| 5769 | ggml_vk_queue_command_pools_cleanup(dst->device);
| 5770 | } |
| 5771 | } |
| 5772 | |
| 5773 | static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { |
| 5774 | VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")" ); |
| 5775 | ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
| 5776 | } |
| 5777 | |
| 5778 | static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { |
| 5779 | VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")" ); |
| 5780 | GGML_ASSERT(width > 0); |
| 5781 | GGML_ASSERT(height > 0); |
| 5782 | GGML_ASSERT(src != nullptr); |
| 5783 | |
| 5784 | // TODO: staging_offset is not used |
| 5785 | |
| 5786 | // Check if dst is pinned memory |
| 5787 | vk_buffer buf = nullptr; |
| 5788 | size_t buf_offset = 0; |
| 5789 | ggml_vk_host_get(src->device, dst, buf, buf_offset);
| 5790 | |
| 5791 | std::vector<vk::BufferCopy> slices(1); |
| 5792 | if (width == spitch && width == dpitch) { |
| 5793 | // Only do single write if stride is equal |
| 5794 | slices[0].srcOffset = offset; |
| 5795 | slices[0].dstOffset = buf_offset; |
| 5796 | slices[0].size = width * height; |
| 5797 | } else { |
| 5798 | slices.resize(height);
| 5799 | for (size_t i = 0; i < height; i++) { |
| 5800 | slices[i].srcOffset = offset + i * spitch; |
| 5801 | slices[i].dstOffset = buf_offset + i * dpitch; |
| 5802 | slices[i].size = width; |
| 5803 | } |
| 5804 | } |
| 5805 | |
| 5806 | if (buf != nullptr) { |
| 5807 | // Memory is pinned, use as staging buffer |
| 5808 | ggml_vk_sync_buffers(nullptr, subctx);
| 5809 | subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
| 5810 | |
| 5811 | return; |
| 5812 | } |
| 5813 | VK_LOG_DEBUG("STAGING" ); |
| 5814 | |
| 5815 | if (!sync_staging) { |
| 5816 | GGML_ABORT("Asynchronous read from non-pinned memory not supported" ); |
| 5817 | } |
| 5818 | |
| 5819 | // Fall back to staging buffer |
| 5820 | const size_t copy_size = dpitch * height; |
| 5821 | ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
| 5822 | 
| 5823 | vk_buffer& staging_buffer = src->device->sync_staging;
| 5824 | 
| 5825 | ggml_vk_sync_buffers(nullptr, subctx);
| 5826 | subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
| 5827 | 
| 5828 | deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
| 5829 | } |
| 5830 | |
| 5831 | static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { |
| 5832 | return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
| 5833 | } |
| 5834 | |
| 5835 | static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { |
| 5836 | VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")" ); |
| 5837 | |
| 5838 | // If the device is not a UMA device, the memory may still be host-accessible through ReBAR.
| 5839 | // Writing through PCIe is sufficiently fast, but reading back over PCIe is slower than going
| 5840 | // through the device-to-host copy path, so only read directly on UMA devices.
| 5841 | if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { |
| 5842 | GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); |
| 5843 | |
| 5844 | memcpy(dst, (uint8_t *) src->ptr + offset, size);
| 5845 | } else { |
| 5846 | std::lock_guard<std::recursive_mutex> guard(src->device->mutex); |
| 5847 | |
| 5848 | vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
| 5849 | ggml_vk_ctx_begin(src->device, subctx);
| 5850 | ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
| 5851 | ggml_vk_ctx_end(subctx);
| 5852 | 
| 5853 | ggml_vk_submit(subctx, src->device->fence);
| 5854 | VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
| 5855 | src->device->device.resetFences({ src->device->fence });
| 5856 | ggml_vk_queue_command_pools_cleanup(src->device);
| 5857 | 
| 5858 | for (auto& cpy : subctx->out_memcpys) {
| 5859 | memcpy(cpy.dst, cpy.src, cpy.n);
| 5860 | } |
| 5861 | } |
| 5862 | } |
| 5863 | |
| 5864 | static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { |
| 5865 | VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")" ); |
| 5866 | // Make sure both buffers are on same device |
| 5867 | GGML_ASSERT(src->device == dst->device); |
| 5868 | |
| 5869 | VkBufferCopy bc{ src_offset, dst_offset, size };
| 5870 | 
| 5871 | vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
| 5872 | } |
| 5873 | |
| 5874 | static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { |
| 5875 | if (src->device == dst->device) { |
| 5876 | std::lock_guard<std::recursive_mutex> guard(src->device->mutex); |
| 5877 | VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")" ); |
| 5878 | // Copy within the device |
| 5879 | vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
| 5880 | ggml_vk_ctx_begin(src->device, subctx);
| 5881 | ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
| 5882 | ggml_vk_ctx_end(subctx);
| 5883 | ggml_vk_submit(subctx, src->device->fence);
| 5884 | VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
| 5885 | src->device->device.resetFences({ src->device->fence });
| 5886 | ggml_vk_queue_command_pools_cleanup(src->device);
| 5887 | } else { |
| 5888 | VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")" ); |
| 5889 | // Copy device to device |
| 5890 | ggml_vk_ensure_sync_staging_buffer(src->device, size);
| 5891 | 
| 5892 | // Copy to src staging buffer
| 5893 | ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
| 5894 | // Copy to dst buffer
| 5895 | ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
| 5896 | } |
| 5897 | } |
| 5898 | |
| 5899 | static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { |
| 5900 | VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")" ); |
| 5901 | |
| 5902 | if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && |
| 5903 | dst->device->uma) { |
| 5904 | deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
| 5905 | return; |
| 5906 | } |
| 5907 | |
| 5908 | // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers |
| 5909 | ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
| 5910 | } |
| 5911 | |
| 5912 | static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { |
| 5913 | VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")" ); |
| 5914 | |
| 5915 | if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && |
| 5916 | dst->device->uma) { |
| 5917 | memset((uint8_t*)dst->ptr + offset, c, size);
| 5918 | return; |
| 5919 | } |
| 5920 | |
| 5921 | std::lock_guard<std::recursive_mutex> guard(dst->device->mutex); |
| 5922 | vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
| 5923 | ggml_vk_ctx_begin(dst->device, subctx);
| 5924 | subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
| 5925 | ggml_vk_ctx_end(subctx);
| 5926 | 
| 5927 | ggml_vk_submit(subctx, dst->device->fence);
| 5928 | VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
| 5929 | dst->device->device.resetFences({ dst->device->fence });
| 5930 | ggml_vk_queue_command_pools_cleanup(dst->device);
| 5931 | } |
| 5932 | |
| 5933 | static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) { |
| 5934 | VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")" ); |
| 5935 | |
| 5936 | if (disable_split_k) { |
| 5937 | return 1; |
| 5938 | } |
| 5939 | |
| 5940 | uint32_t split_k = 1; |
| 5941 | if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) { |
| 5942 | // If k is 'large' and the SMs will fill less than halfway, use split_k. |
| 5943 | uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); |
| 5944 | uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); |
| 5945 | |
| 5946 | if (k >= 2048) { |
| 5947 | if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) { |
| 5948 | split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); |
| 5949 | } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) { |
| 5950 | split_k = 3; |
| 5951 | } |
| 5952 | // Cap the split at 8x. Unless k is huge this is a lot of overhead. |
| 5953 | split_k = std::min(split_k, 8u);
| 5954 | |
| 5955 | // ggml_vk_matmul will align the splits to be a multiple of 256. |
| 5956 | // If this rounded up size would cause the last split to be empty, |
| 5957 | // then reduce the split count. |
| 5958 | while (true) { |
| 5959 | if (split_k == 1) { |
| 5960 | break; |
| 5961 | } |
| 5962 | uint32_t k_split = CEIL_DIV(k, split_k); |
| 5963 | k_split = ROUNDUP_POW2(k_split, 256); |
| 5964 | if (k_split * (split_k - 1) < k) { |
| 5965 | break; |
| 5966 | } |
| 5967 | split_k--; |
| 5968 | } |
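|      | // Example: k = 4096 and split_k = 3 gives CEIL_DIV(4096, 3) = 1366, rounded up to 1536;
|      | // 1536 * 2 = 3072 < 4096, so every split still gets work and split_k stays at 3.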
| 5969 | } |
| 5970 | } |
| 5971 | |
| 5972 | return split_k; |
| 5973 | } |
| 5974 | |
| 5975 | static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { |
| 5976 | VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")" ); |
| 5977 | |
| 5978 | if (ctx->device->coopmat2) { |
| 5979 | const uint32_t shader_core_count = ctx->device->shader_core_count; |
| 5980 | const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); |
| 5981 | const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]); |
| 5982 | |
| 5983 | // Use large shader when the N dimension is greater than the medium shader's tile size |
| 5984 | uint32_t crossover_large = mmp->m->wg_denoms[1]; |
| 5985 | |
| 5986 | // Prefer large over medium if either: |
| 5987 | // - medium or large tiles would overfill the GPU |
| 5988 | // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not |
| 5989 | // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead) |
| 5990 | bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count || |
| 5991 | // split_k==3 with large tiles likely better than medium tiles with no split_k. |
| 5992 | (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); |
| 5993 | |
| 5994 | if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { |
| 5995 | return aligned ? mmp->a_l : mmp->l; |
| 5996 | } |
| 5997 | // Use medium shader when the N dimension is greater than the small shader's tile size |
| 5998 | uint32_t crossover_medium = mmp->s->wg_denoms[1]; |
| 5999 | if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { |
| 6000 | return aligned ? mmp->a_m : mmp->m; |
| 6001 | } |
| 6002 | return aligned ? mmp->a_s : mmp->s; |
| 6003 | } |
| 6004 | |
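|      | // Without coopmat2, fall back to a simple size heuristic: small tiles when either output
|      | // dimension is at most 32, medium when at most 64, large otherwise, constrained by which tile
|      | // variants were actually compiled for this src0 type.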
| 6005 | if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { |
| 6006 | return aligned ? mmp->a_s : mmp->s; |
| 6007 | } |
| 6008 | if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { |
| 6009 | return aligned ? mmp->a_m : mmp->m; |
| 6010 | } |
| 6011 | return aligned ? mmp->a_l : mmp->l; |
| 6012 | |
| 6013 | GGML_UNUSED(src1_type); |
| 6014 | } |
| 6015 | |
| 6016 | static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { |
| 6017 | VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")" ); |
| 6018 | return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
| 6019 | } |
| 6020 | |
| 6021 | static void ggml_vk_matmul( |
| 6022 | ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, |
| 6023 | vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, |
| 6024 | uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, |
| 6025 | uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, |
| 6026 | uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, |
| 6027 | uint32_t padded_n) { |
| 6028 | VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")" ); |
| 6029 | if (split_k == 1) { |
| 6030 | const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
| 6031 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
| 6032 | return; |
| 6033 | } |
| 6034 | |
| 6035 | if (ctx->prealloc_split_k_need_sync) { |
| 6036 | ggml_vk_sync_buffers(ctx, subctx); |
| 6037 | } |
| 6038 | |
| 6039 | GGML_ASSERT(batch_stride_d == m * n); |
| 6040 | |
| 6041 | // Round the split size up to a multiple of 256 (k-quant alignment) |
| 6042 | uint32_t k_split = CEIL_DIV(k, split_k); |
| 6043 | k_split = ROUNDUP_POW2(k_split, 256); |
| 6044 | |
| 6045 | const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
| 6046 | // Make sure enough workgroups get assigned for split k to work
| 6047 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
| 6048 | ggml_vk_sync_buffers(ctx, subctx);
| 6049 | const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
| 6050 | ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
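|      | // The two dispatches above implement split-k: the matmul pass writes split_k partial results into
|      | // split_k_buffer, and pipeline_matmul_split_k_reduce then sums those partials for each of the
|      | // m*n*batch output elements into d.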
| 6051 | ctx->prealloc_split_k_need_sync = true; |
| 6052 | } |
| 6053 | |
| 6054 | static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { |
| 6055 | VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")" ); |
| 6056 | |
| 6057 | if (ctx->device->coopmat2) { |
| 6058 | // Use large shader when the N dimension is greater than the medium shader's tile size |
| 6059 | uint32_t crossover_large = mmp->m->wg_denoms[1]; |
| 6060 | if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { |
| 6061 | return aligned ? mmp->a_l : mmp->l; |
| 6062 | } |
| 6063 | // Use medium shader when the N dimension is greater than the small shader's tile size |
| 6064 | uint32_t crossover_medium = mmp->s->wg_denoms[1]; |
| 6065 | if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { |
| 6066 | return aligned ? mmp->a_m : mmp->m; |
| 6067 | } |
| 6068 | return aligned ? mmp->a_s : mmp->s; |
| 6069 | } |
| 6070 | |
| 6071 | if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { |
| 6072 | return aligned ? mmp->a_s : mmp->s; |
| 6073 | } |
| 6074 | if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { |
| 6075 | return aligned ? mmp->a_m : mmp->m; |
| 6076 | } |
| 6077 | return aligned ? mmp->a_l : mmp->l; |
| 6078 | } |
| 6079 | |
| 6080 | static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { |
| 6081 | VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")" ); |
| 6082 | return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, aligned: true, src0_type)->align; |
| 6083 | } |
| 6084 | |
| 6085 | static void ggml_vk_matmul_id( |
| 6086 | ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, |
| 6087 | vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, |
| 6088 | uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, |
| 6089 | uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, |
| 6090 | uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, |
| 6091 | uint32_t padded_n) { |
| 6092 | VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << |
| 6093 | "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << |
| 6094 | "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << |
| 6095 | "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")" ); |
| 6096 | const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
| 6097 | nei0, nei1, nbi1, ne11, padded_n };
| 6098 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
| 6099 | } |
| 6100 | |
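|      | // "dim01 contiguous" means rows are densely packed in dims 0 and 1; the stride of dim 2 may be
|      | // larger (e.g. for views), but dim 3 must follow dim 2 contiguously unless ne[3] == 1.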
| 6101 | static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { |
| 6102 | return |
| 6103 | tensor->nb[0] == ggml_type_size(tensor->type) &&
| 6104 | tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
| 6105 | (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); |
| 6106 | } |
| 6107 | |
| 6108 | static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { |
| 6109 | |
| 6110 | // Choose "contiguous copy" shader if src/dst are contiguous |
| 6111 | bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
| 6112 | |
| 6113 | if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { |
| 6114 | if (contig) { |
| 6115 | return ctx->device->pipeline_contig_cpy_f32_f32; |
| 6116 | } else { |
| 6117 | return ctx->device->pipeline_cpy_f32_f32; |
| 6118 | } |
| 6119 | } |
| 6120 | if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { |
| 6121 | if (contig) { |
| 6122 | return ctx->device->pipeline_contig_cpy_f32_f16; |
| 6123 | } else { |
| 6124 | return ctx->device->pipeline_cpy_f32_f16; |
| 6125 | } |
| 6126 | } |
| 6127 | if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { |
| 6128 | if (contig) { |
| 6129 | return ctx->device->pipeline_contig_cpy_f16_f16; |
| 6130 | } else { |
| 6131 | return ctx->device->pipeline_cpy_f16_f16; |
| 6132 | } |
| 6133 | } |
| 6134 | if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { |
| 6135 | if (contig) { |
| 6136 | return ctx->device->pipeline_contig_cpy_f16_f32; |
| 6137 | } else { |
| 6138 | return ctx->device->pipeline_cpy_f16_f32; |
| 6139 | } |
| 6140 | } |
| 6141 | if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { |
| 6142 | if (contig) { |
| 6143 | return ctx->device->pipeline_contig_cpy_f32_bf16; |
| 6144 | } else { |
| 6145 | return ctx->device->pipeline_cpy_f32_bf16; |
| 6146 | } |
| 6147 | } |
| 6148 | if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { |
| 6149 | if (contig) { |
| 6150 | return ctx->device->pipeline_contig_cpy_f32_i32; |
| 6151 | } else { |
| 6152 | return ctx->device->pipeline_cpy_f32_i32; |
| 6153 | } |
| 6154 | } |
| 6155 | if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) { |
| 6156 | if (contig) { |
| 6157 | return ctx->device->pipeline_contig_cpy_i32_f32; |
| 6158 | } else { |
| 6159 | return ctx->device->pipeline_cpy_i32_f32; |
| 6160 | } |
| 6161 | } |
| 6162 | if (src->type == GGML_TYPE_F32) { |
| 6163 | switch (to) { |
| 6164 | case GGML_TYPE_Q4_0: |
| 6165 | case GGML_TYPE_Q4_1: |
| 6166 | case GGML_TYPE_Q5_0: |
| 6167 | case GGML_TYPE_Q5_1: |
| 6168 | case GGML_TYPE_Q8_0: |
| 6169 | case GGML_TYPE_IQ4_NL: |
| 6170 | return ctx->device->pipeline_cpy_f32_quant[to]; |
| 6171 | default: |
| 6172 | break; |
| 6173 | } |
| 6174 | } |
| 6175 | |
| 6176 | if (to == GGML_TYPE_F32) { |
| 6177 | switch (src->type) { |
| 6178 | case GGML_TYPE_Q4_0: |
| 6179 | case GGML_TYPE_Q4_1: |
| 6180 | case GGML_TYPE_Q5_0: |
| 6181 | case GGML_TYPE_Q5_1: |
| 6182 | case GGML_TYPE_Q8_0: |
| 6183 | case GGML_TYPE_IQ4_NL: |
| 6184 | return ctx->device->pipeline_cpy_quant_f32[src->type]; |
| 6185 | default: |
| 6186 | break; |
| 6187 | } |
| 6188 | } |
| 6189 | |
| 6190 | if (src->type == to) { |
| 6191 | // Copy two or four bytes at a time, depending on block size. |
| 6192 | // For quantized types, we scale by block size/type size. But |
| 6193 | // this path is also used for bf16->bf16 for example, where the |
| 6194 | // type size must be exactly 2 or 4. |
| 6195 | GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); |
| 6196 | if ((ggml_type_size(src->type) % 4) == 0) {
| 6197 | if (contig) { |
| 6198 | return ctx->device->pipeline_contig_cpy_f32_f32; |
| 6199 | } else { |
| 6200 | return ctx->device->pipeline_cpy_f32_f32; |
| 6201 | } |
| 6202 | } else { |
| 6203 | if (contig) { |
| 6204 | return ctx->device->pipeline_contig_cpy_f16_f16; |
| 6205 | } else { |
| 6206 | return ctx->device->pipeline_cpy_f16_f16; |
| 6207 | } |
| 6208 | } |
| 6209 | } |
| 6210 | |
| 6211 | std::cerr << "Missing CPY op for types: " << ggml_type_name(type: src->type) << " " << ggml_type_name(type: to) << std::endl; |
| 6212 | GGML_ABORT("fatal error" ); |
| 6213 | } |
| 6214 | |
| 6215 | static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { |
| 6216 | VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), " ; |
| 6217 | std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")" ); |
| 6218 | const int tensor_type_size = ggml_type_size(tensor->type);
| 6219 | |
| 6220 | const uint32_t ne = ggml_nelements(tensor); |
| 6221 | std::array<uint32_t, 3> elements; |
| 6222 | |
| 6223 | if (ne > 262144) { |
| 6224 | elements = { 512, 512, CEIL_DIV(ne, 262144) }; |
| 6225 | } else if (ne > 512) { |
| 6226 | elements = { 512, CEIL_DIV(ne, 512), 1 }; |
| 6227 | } else { |
| 6228 | elements = { ne, 1, 1 }; |
| 6229 | } |
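|      | // The flat element count is split across up to three dispatch dimensions (512 x 512 x N) so that
|      | // the first two dimensions stay within typical per-dimension workgroup-count limits.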
| 6230 | |
| 6231 | vk_op_unary_push_constants pc = { |
| 6232 | (uint32_t)ne,
| 6233 | (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
| 6234 | (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1, (uint32_t)tensor->ne[0], (uint32_t)(tensor->ne[0] * tensor->ne[1]), (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
| 6235 | 0,
| 6236 | 0.0f, 0.0f,
| 6237 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
| 6238 | };
| 6239 | init_pushconst_fastdiv(pc);
| 6240 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
| 6241 | ggml_vk_sync_buffers(ctx, subctx); |
| 6242 | } |
| 6243 | |
| 6244 | static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) { |
| 6245 | switch(type) { |
| 6246 | case GGML_TYPE_Q8_1: |
| 6247 | return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1; |
| 6248 | default: |
| 6249 | std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; |
| 6250 | GGML_ABORT("fatal error" ); |
| 6251 | } |
| 6252 | } |
| 6253 | |
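|      | // GPU-side quantization of a contiguous run of F32 values to Q8_1. With use_x4_blocks
|      | // the output packs four Q8_1 blocks together (the x4 layout used by the mmq/mmvq
|      | // shaders for wider aligned loads); otherwise a plain Q8_1 stream is written.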
| 6254 | static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) { |
| 6255 | VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")" ); |
| 6256 | |
| 6257 | vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
| 6258 |
| 6259 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
| 6260 | ggml_vk_sync_buffers(ctx, subctx); |
| 6261 | } |
| 6262 | |
| 6263 | static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) { |
| 6264 | VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; |
| 6265 | std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; |
| 6266 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; |
| 6267 | std::cerr << "))" ); |
| 6268 | GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT |
| 6269 | GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT |
| 6270 | |
| 6271 | const uint64_t ne00 = src0->ne[0]; |
| 6272 | const uint64_t ne01 = src0->ne[1]; |
| 6273 | const uint64_t ne02 = src0->ne[2]; |
| 6274 | const uint64_t ne03 = src0->ne[3]; |
| 6275 | |
| 6276 | const uint64_t ne10 = src1->ne[0]; |
| 6277 | const uint64_t ne11 = src1->ne[1]; |
| 6278 | const uint64_t ne12 = src1->ne[2]; |
| 6279 | const uint64_t ne13 = src1->ne[3]; |
| 6280 | |
| 6281 | const uint64_t ne21 = dst->ne[1]; |
| 6282 | const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
| 6283 | const uint32_t stride_batch_d = stride_d*ne21; |
| 6284 | |
| 6285 | const uint64_t r2 = ne12 / ne02; |
| 6286 | const uint64_t r3 = ne13 / ne03; |
| 6287 | |
| 6288 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; |
| 6289 | ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; |
| 6290 | ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; |
| 6291 | |
| 6292 | vk_buffer d_Qx = nullptr; |
| 6293 | size_t qx_buf_offset = 0; |
| 6294 | vk_buffer d_Qy = nullptr; |
| 6295 | size_t qy_buf_offset = 0; |
| 6296 | |
| 6297 | bool src0_uma = false; |
| 6298 | bool src1_uma = false; |
| 6299 | |
| 6300 | if (ctx->device->uma) { |
| 6301 | ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
| 6302 | ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
| 6303 | src0_uma = d_Qx != nullptr; |
| 6304 | src1_uma = d_Qy != nullptr; |
| 6305 | } |
| 6306 | |
| 6307 | // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf |
| 6308 | const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || |
| 6309 | !ggml_vk_dim01_contiguous(src0);
| 6310 | const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
| 6311 | (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
| 6312 | !ggml_vk_dim01_contiguous(src1);
| 6313 | |
| 6314 | // If src0 is BF16, try to use a BF16 x BF16 multiply |
| 6315 | ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; |
| 6316 | |
| 6317 | const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; |
| 6318 | |
| 6319 | bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
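|      | // The (ne11 * ne10) % 4 == 0 requirement presumably lets the quantize shader consume
|      | // the source four values at a time without edge handling; odd sizes fall back to the
|      | // non-quantized path.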
| 6320 | |
| 6321 | // Check for mmq first |
| 6322 | vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
| 6323 |
| 6324 | if (mmp == nullptr) {
| 6325 | // Fall back to f16 dequant mul mat
| 6326 | mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
| 6327 | quantize_y = false; |
| 6328 | } |
| 6329 | |
| 6330 | const bool qx_needs_dequant = mmp == nullptr || x_non_contig; |
| 6331 | const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); |
| 6332 | |
| 6333 | if (qx_needs_dequant) { |
| 6334 | // Fall back to dequant + f16 mulmat |
| 6335 | mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
| 6336 | } |
| 6337 | |
| 6338 | // Not implemented |
| 6339 | GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT |
| 6340 | |
| 6341 | const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
| 6342 | const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; |
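|      | // "aligned" selects a matmul variant that assumes K is already a multiple of the
|      | // pipeline alignment and can skip edge handling; it is only chosen once both M and N
|      | // are larger than 8.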
| 6343 | |
| 6344 | vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
| 6345 | |
| 6346 | // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking |
| 6347 | uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; |
| 6348 | const int x_ne = ne01 * ne00; |
| 6349 | const int y_ne = padded_n * ne10; |
| 6350 | const int d_ne = ne11 * ne01; |
| 6351 | |
| 6352 | const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
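|      | // split_k > 1 splits the K dimension across several workgroups when M x N alone does
|      | // not expose enough parallelism; partial results go to prealloc_split_k and are summed
|      | // by pipeline_matmul_split_k_reduce.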
| 6353 | |
| 6354 | const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
| 6355 | const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
| 6356 | const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
| 6357 | const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
| 6358 | const uint64_t d_sz = sizeof(float) * d_ne; |
| 6359 | |
| 6360 | vk_pipeline to_fp16_vk_0 = nullptr; |
| 6361 | vk_pipeline to_fp16_vk_1 = nullptr; |
| 6362 | vk_pipeline to_q8_1 = nullptr; |
| 6363 | |
| 6364 | if (x_non_contig) { |
| 6365 | to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
| 6366 | } else {
| 6367 | to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
| 6368 | }
| 6369 | if (y_non_contig) {
| 6370 | to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
| 6371 | } else {
| 6372 | to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
| 6373 | } |
| 6374 | GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT |
| 6375 | GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT |
| 6376 | |
| 6377 | if (quantize_y) { |
| 6378 | to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
| 6379 | } |
| 6380 | |
| 6381 | { |
| 6382 | const uint64_t x_sz_upd = x_sz * ne02 * ne03; |
| 6383 | uint64_t y_sz_upd = y_sz * ne12 * ne13; |
| 6384 | if (quantize_y) { |
| 6385 | y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; |
| 6386 | } |
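|      | // Rounding to a multiple of 144 bytes keeps whole x4 groups: one Q8_1 block is
|      | // 36 bytes (32 int8 values plus fp16 scale and sum), so four packed blocks are 144.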
| 6387 | const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; |
| 6388 | if ( |
| 6389 | (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || |
| 6390 | (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || |
| 6391 | (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) { |
| 6392 | GGML_ABORT("Requested preallocation size is too large" ); |
| 6393 | } |
| 6394 | if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { |
| 6395 | ctx->prealloc_size_x = x_sz_upd; |
| 6396 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 6397 | } |
| 6398 | if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { |
| 6399 | ctx->prealloc_size_y = y_sz_upd; |
| 6400 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 6401 | } |
| 6402 | if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { |
| 6403 | ctx->prealloc_size_split_k = split_k_size; |
| 6404 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 6405 | } |
| 6406 | |
| 6407 | // Request descriptor sets |
| 6408 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
| 6409 | if (qx_needs_dequant) {
| 6410 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
| 6411 | }
| 6412 | if (qy_needs_dequant) {
| 6413 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
| 6414 | }
| 6415 | if (quantize_y) {
| 6416 | ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
| 6417 | }
| 6418 | if (split_k > 1) {
| 6419 | ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
| 6420 | } |
| 6421 | } |
| 6422 | |
| 6423 | vk_buffer d_D = dst_buf_ctx->dev_buffer; |
| 6424 | const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
| 6425 | GGML_ASSERT(d_D != nullptr); |
| 6426 | GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); |
| 6427 | vk_buffer d_X; |
| 6428 | uint64_t x_buf_offset = 0; |
| 6429 | vk_buffer d_Y; |
| 6430 | uint64_t y_buf_offset = 0; |
| 6431 | if (!src0_uma) { |
| 6432 | d_Qx = src0_buf_ctx->dev_buffer; |
| 6433 | qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
| 6434 | GGML_ASSERT(d_Qx != nullptr);
| 6435 | }
| 6436 | if (!src1_uma) {
| 6437 | d_Qy = src1_buf_ctx->dev_buffer;
| 6438 | qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
| 6439 | GGML_ASSERT(d_Qy != nullptr); |
| 6440 | } |
| 6441 | if (qx_needs_dequant) { |
| 6442 | d_X = ctx->prealloc_x; |
| 6443 | GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); |
| 6444 | } else { |
| 6445 | d_X = d_Qx; |
| 6446 | x_buf_offset = qx_buf_offset; |
| 6447 | GGML_ASSERT(qx_sz == x_sz); |
| 6448 | } |
| 6449 | if (qy_needs_dequant) { |
| 6450 | d_Y = ctx->prealloc_y; |
| 6451 | GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); |
| 6452 | } else if (quantize_y) { |
| 6453 | d_Y = ctx->prealloc_y; |
| 6454 | GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); |
| 6455 | } else { |
| 6456 | d_Y = d_Qy; |
| 6457 | y_buf_offset = qy_buf_offset; |
| 6458 | GGML_ASSERT(qy_sz == y_sz); |
| 6459 | } |
| 6460 | |
| 6461 | if (x_non_contig || qx_needs_dequant) { |
| 6462 | if (ctx->prealloc_x_need_sync) { |
| 6463 | ggml_vk_sync_buffers(ctx, subctx); |
| 6464 | } |
| 6465 | } |
| 6466 | |
| 6467 | if (x_non_contig) { |
| 6468 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
| 6469 | } else if (qx_needs_dequant) {
| 6470 | const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
| 6471 | ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
| 6472 | ggml_vk_sync_buffers(ctx, subctx); |
| 6473 | } |
| 6474 | if (y_non_contig) { |
| 6475 | if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || |
| 6476 | ctx->prealloc_y_last_tensor_used != src1) { |
| 6477 | if (ctx->prealloc_y_need_sync) { |
| 6478 | ggml_vk_sync_buffers(ctx, subctx); |
| 6479 | } |
| 6480 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
| 6481 | ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); |
| 6482 | ctx->prealloc_y_last_tensor_used = src1; |
| 6483 | } |
| 6484 | } |
| 6485 | if (quantize_y) { |
| 6486 | if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || |
| 6487 | ctx->prealloc_y_last_tensor_used != src1) { |
| 6488 | if (ctx->prealloc_y_need_sync) { |
| 6489 | ggml_vk_sync_buffers(ctx, subctx); |
| 6490 | } |
| 6491 | ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
| 6492 | ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); |
| 6493 | ctx->prealloc_y_last_tensor_used = src1; |
| 6494 | } |
| 6495 | } |
| 6496 | |
| 6497 | uint32_t stride_batch_x = ne00*ne01; |
| 6498 | uint32_t stride_batch_y = ne10*ne11; |
| 6499 | |
| 6500 | if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
| 6501 | stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
| 6502 | }
| 6503 |
| 6504 | if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
| 6505 | stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
| 6506 | } |
| 6507 | |
| 6508 | uint32_t y_sz_total = y_sz * ne12 * ne13; |
| 6509 | if (quantize_y) { |
| 6510 | y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; |
| 6511 | } |
| 6512 | |
| 6513 | // compute |
| 6514 | ggml_vk_matmul(
| 6515 | ctx, subctx, pipeline,
| 6516 | { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
| 6517 | ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
| 6518 | ne01, ne11, ne10,
| 6519 | ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
| 6520 | split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
| 6521 | ); // NOLINT
| 6522 | |
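|      | // Mark the preallocated staging buffers as dirty: the next op that wants to reuse
|      | // prealloc_x / prealloc_y must issue a barrier first (see the sync checks above).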
| 6523 | if (x_non_contig || qx_needs_dequant) { |
| 6524 | ctx->prealloc_x_need_sync = true; |
| 6525 | } |
| 6526 | if (y_non_contig || quantize_y) { |
| 6527 | ctx->prealloc_y_need_sync = true; |
| 6528 | } |
| 6529 | } |
| 6530 | |
| 6531 | // Device tuning |
| 6532 | static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) { |
| 6533 | if (device->mmvq_mode == 1) { |
| 6534 | return true; |
| 6535 | } else if (device->mmvq_mode == -1) { |
| 6536 | return false; |
| 6537 | } |
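|      | // mmvq_mode == 1 forces the integer-dot (MMVQ) path on and -1 forces it off; any other
|      | // value falls through to the per-vendor heuristics below.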
| 6538 | |
| 6539 | // MMVQ is generally good for batches |
| 6540 | if (n > 1) { |
| 6541 | return true; |
| 6542 | } |
| 6543 | |
| 6544 | switch (device->vendor_id) { |
| 6545 | case VK_VENDOR_ID_NVIDIA: |
| 6546 | switch (src0_type) { |
| 6547 | case GGML_TYPE_Q8_0: |
| 6548 | return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; |
| 6549 | default: |
| 6550 | return true; |
| 6551 | } |
| 6552 | case VK_VENDOR_ID_AMD: |
| 6553 | switch (src0_type) { |
| 6554 | case GGML_TYPE_Q8_0: |
| 6555 | return device->architecture == vk_device_architecture::AMD_GCN; |
| 6556 | default: |
| 6557 | return true; |
| 6558 | } |
| 6559 | case VK_VENDOR_ID_INTEL: |
| 6560 | switch (src0_type) { |
| 6561 | // From tests on A770 Linux, may need more tuning |
| 6562 | case GGML_TYPE_Q4_0: |
| 6563 | case GGML_TYPE_Q5_1: |
| 6564 | return false; |
| 6565 | default: |
| 6566 | return true; |
| 6567 | } |
| 6568 | default: |
| 6569 | return true; |
| 6570 | } |
| 6571 | |
| 6572 | GGML_UNUSED(m); |
| 6573 | GGML_UNUSED(k); |
| 6574 | } |
| 6575 | |
| 6576 | static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { |
| 6577 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 6578 | const ggml_tensor * src0 = dst->src[0]; |
| 6579 | const ggml_tensor * src1 = dst->src[1]; |
| 6580 | |
| 6581 | VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; |
| 6582 | std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; |
| 6583 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; |
| 6584 | std::cerr << ")),)" ); |
| 6585 | GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT |
| 6586 | GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT |
| 6587 | |
| 6588 | const uint64_t ne00 = src0->ne[0]; |
| 6589 | const uint64_t ne01 = src0->ne[1]; |
| 6590 | const uint64_t ne02 = src0->ne[2]; |
| 6591 | const uint64_t ne03 = src0->ne[3]; |
| 6592 | |
| 6593 | const uint64_t ne10 = src1->ne[0]; |
| 6594 | const uint64_t ne11 = src1->ne[1]; |
| 6595 | const uint64_t ne12 = src1->ne[2]; |
| 6596 | const uint64_t ne13 = src1->ne[3]; |
| 6597 | |
| 6598 | const uint64_t ne20 = dst->ne[0]; |
| 6599 | const uint64_t ne21 = dst->ne[1]; |
| 6600 | const uint64_t ne22 = dst->ne[2]; |
| 6601 | const uint64_t ne23 = dst->ne[3]; |
| 6602 | |
| 6603 | const uint64_t r2 = ne12 / ne02; |
| 6604 | const uint64_t r3 = ne13 / ne03; |
| 6605 | |
| 6606 | // batch_n indicates that we need to compute a few vector results, and this assumes |
| 6607 | // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. |
| 6608 | GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); |
| 6609 | bool batch_n = ne11 > 1; |
| 6610 | |
| 6611 | ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; |
| 6612 | ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; |
| 6613 | |
| 6614 | vk_buffer d_Qx = nullptr; |
| 6615 | size_t qx_buf_offset = 0; |
| 6616 | vk_buffer d_Qy = nullptr; |
| 6617 | size_t qy_buf_offset = 0; |
| 6618 | |
| 6619 | bool src0_uma = false; |
| 6620 | bool src1_uma = false; |
| 6621 | |
| 6622 | if (ctx->device->uma) { |
| 6623 | ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
| 6624 | ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
| 6625 | src0_uma = d_Qx != nullptr; |
| 6626 | src1_uma = d_Qy != nullptr; |
| 6627 | } |
| 6628 | |
| 6629 | const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
| 6630 | const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
| 6631 |
| 6632 | const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
| 6633 | bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
| 6634 | |
| 6635 | vk_pipeline to_fp16_vk_0 = nullptr; |
| 6636 | vk_pipeline to_fp16_vk_1 = nullptr; |
| 6637 | if (x_non_contig) { |
| 6638 | to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
| 6639 | }
| 6640 | if (y_non_contig) {
| 6641 | to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
| 6642 | } else {
| 6643 | to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
| 6644 | } |
| 6645 | |
| 6646 | // Check for mmq first |
| 6647 | vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;
| 6648 | vk_pipeline to_q8_1 = nullptr;
| 6649 |
| 6650 | if (dmmv == nullptr) {
| 6651 | // Fall back to f16 dequant mul mat
| 6652 | dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
| 6653 | quantize_y = false; |
| 6654 | } |
| 6655 | |
| 6656 | if (quantize_y) { |
| 6657 | to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
| 6658 | } |
| 6659 | |
| 6660 | const bool qx_needs_dequant = x_non_contig; |
| 6661 | const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); |
| 6662 | |
| 6663 | // Not implemented |
| 6664 | GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT |
| 6665 | |
| 6666 | GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT |
| 6667 | GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT |
| 6668 | GGML_ASSERT(dmmv != nullptr); |
| 6669 | |
| 6670 | const uint64_t x_ne = ne01 * ne00; |
| 6671 | const uint64_t y_ne = ne11 * ne10; |
| 6672 | const uint64_t d_ne = ne11 * ne01; |
| 6673 | |
| 6674 | const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
| 6675 | const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
| 6676 | const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
| 6677 | const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
| 6678 | const uint64_t d_sz = sizeof(float) * d_ne; |
| 6679 | |
| 6680 | { |
| 6681 | const uint64_t x_sz_upd = x_sz * ne02 * ne03; |
| 6682 | uint64_t y_sz_upd = y_sz * ne12 * ne13; |
| 6683 | if (quantize_y) { |
| 6684 | y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; |
| 6685 | } |
| 6686 | if ( |
| 6687 | (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || |
| 6688 | (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { |
| 6689 | GGML_ABORT("Requested preallocation size is too large" ); |
| 6690 | } |
| 6691 | if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { |
| 6692 | ctx->prealloc_size_x = x_sz_upd; |
| 6693 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 6694 | } |
| 6695 | if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { |
| 6696 | ctx->prealloc_size_y = y_sz_upd; |
| 6697 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 6698 | } |
| 6699 | |
| 6700 | // Request descriptor sets |
| 6701 | if (qx_needs_dequant) {
| 6702 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
| 6703 | }
| 6704 | if (qy_needs_dequant) {
| 6705 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
| 6706 | }
| 6707 | if (quantize_y) {
| 6708 | ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
| 6709 | }
| 6710 | ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
| 6711 | } |
| 6712 | |
| 6713 | vk_buffer d_D; |
| 6714 | uint64_t d_buf_offset = 0; |
| 6715 | |
| 6716 | if (ctx->num_additional_fused_ops > 0) { |
| 6717 | const ggml_tensor * add = cgraph->nodes[node_idx + 1]; |
| 6718 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context; |
| 6719 | d_D = dst_buf_ctx->dev_buffer; |
| 6720 | d_buf_offset = vk_tensor_offset(add) + add->view_offs;
| 6721 | } else {
| 6722 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
| 6723 | d_D = dst_buf_ctx->dev_buffer;
| 6724 | d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
| 6725 | } |
| 6726 | |
| 6727 | GGML_ASSERT(d_D != nullptr); |
| 6728 | vk_buffer d_X; |
| 6729 | uint64_t x_buf_offset = 0; |
| 6730 | vk_buffer d_Y; |
| 6731 | uint64_t y_buf_offset = 0; |
| 6732 | if(!src0_uma) { |
| 6733 | d_Qx = src0_buf_ctx->dev_buffer; |
| 6734 | qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
| 6735 | GGML_ASSERT(d_Qx != nullptr);
| 6736 | }
| 6737 | if(!src1_uma) {
| 6738 | d_Qy = src1_buf_ctx->dev_buffer;
| 6739 | qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
| 6740 | GGML_ASSERT(d_Qy != nullptr); |
| 6741 | } |
| 6742 | if (qx_needs_dequant) { |
| 6743 | d_X = ctx->prealloc_x; |
| 6744 | } else { |
| 6745 | d_X = d_Qx; |
| 6746 | x_buf_offset = qx_buf_offset; |
| 6747 | GGML_ASSERT(qx_sz == x_sz); |
| 6748 | } |
| 6749 | if (qy_needs_dequant) { |
| 6750 | d_Y = ctx->prealloc_y; |
| 6751 | } else if (quantize_y) { |
| 6752 | d_Y = ctx->prealloc_y; |
| 6753 | GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); |
| 6754 | } else { |
| 6755 | d_Y = d_Qy; |
| 6756 | y_buf_offset = qy_buf_offset; |
| 6757 | GGML_ASSERT(qy_sz == y_sz); |
| 6758 | } |
| 6759 | |
| 6760 | if (x_non_contig) { |
| 6761 | if (ctx->prealloc_x_need_sync) { |
| 6762 | ggml_vk_sync_buffers(ctx, subctx); |
| 6763 | } |
| 6764 | |
| 6765 | GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); |
| 6766 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
| 6767 | } |
| 6768 | if (y_non_contig) { |
| 6769 | GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); |
| 6770 | if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || |
| 6771 | ctx->prealloc_y_last_tensor_used != src1) { |
| 6772 | if (ctx->prealloc_y_need_sync) { |
| 6773 | ggml_vk_sync_buffers(ctx, subctx); |
| 6774 | } |
| 6775 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
| 6776 | ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); |
| 6777 | ctx->prealloc_y_last_tensor_used = src1; |
| 6778 | } |
| 6779 | } |
| 6780 | if (quantize_y) { |
| 6781 | if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || |
| 6782 | ctx->prealloc_y_last_tensor_used != src1) { |
| 6783 | if (ctx->prealloc_y_need_sync) { |
| 6784 | ggml_vk_sync_buffers(ctx, subctx); |
| 6785 | } |
| 6786 | ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
| 6787 | ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); |
| 6788 | ctx->prealloc_y_last_tensor_used = src1; |
| 6789 | } |
| 6790 | } |
| 6791 | |
| 6792 | // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride |
| 6793 | uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; |
| 6794 | uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); |
| 6795 | uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21); |
| 6796 | |
| 6797 | if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
| 6798 | stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
| 6799 | }
| 6800 |
| 6801 | if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
| 6802 | stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
| 6803 | } |
| 6804 | |
| 6805 | const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; |
| 6806 | |
| 6807 | uint32_t groups_x = ne01; |
| 6808 | uint32_t groups_z = 1; |
| 6809 | |
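|      | // If there are more rows than the device allows in dispatch dimension X, spread them
|      | // over Z as well; the shader is expected to recombine the X and Z workgroup IDs into a
|      | // row index.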
| 6810 | if (ne01 > max_groups_x) { |
| 6811 | groups_z = 64; |
| 6812 | groups_x = CEIL_DIV(groups_x, groups_z); |
| 6813 | } |
| 6814 | |
| 6815 | // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time |
| 6816 | uint32_t y_sz_total = y_sz * ne12 * ne13; |
| 6817 | if (quantize_y) { |
| 6818 | y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; |
| 6819 | } |
| 6820 | |
| 6821 | uint32_t enable_bias = ctx->num_additional_fused_ops > 0; |
| 6822 | |
| 6823 | vk_buffer d_B = d_D; |
| 6824 | size_t b_buf_offset = 0; |
| 6825 | uint64_t b_sz = 0; |
| 6826 | |
| 6827 | if (enable_bias) { |
| 6828 | const ggml_tensor * add = cgraph->nodes[node_idx + 1]; |
| 6829 | const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0]; |
| 6830 | |
| 6831 | bool b_uma = false; |
| 6832 | if (ctx->device->uma) { |
| 6833 | ggml_vk_host_get(ctx->device, bias->data, d_B, b_buf_offset);
| 6834 | b_uma = d_B != nullptr; |
| 6835 | } |
| 6836 | if(!b_uma) { |
| 6837 | ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context; |
| 6838 | d_B = bias_buf_ctx->dev_buffer; |
| 6839 | b_buf_offset = vk_tensor_offset(bias) + bias->view_offs;
| 6840 | GGML_ASSERT(d_B != nullptr);
| 6841 | b_sz = ggml_nbytes(bias);
| 6842 | } |
| 6843 | } |
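|      | // When an ADD is fused behind this mat-vec (num_additional_fused_ops > 0), the ADD's
|      | // other operand is bound as a bias buffer and the result is written directly to the
|      | // ADD's output, so no separate ADD dispatch is needed.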
| 6844 | |
| 6845 | // compute |
| 6846 | const vk_mat_vec_push_constants pc = {
| 6847 | (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
| 6848 | stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
| 6849 | (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
| 6850 | };
| 6851 | ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
| 6852 | {
| 6853 | vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
| 6854 | vk_subbuffer{ d_Y, y_buf_offset, y_sz_total },
| 6855 | vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23},
| 6856 | vk_subbuffer{ d_B, b_buf_offset, b_sz },
| 6857 | },
| 6858 | pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
| 6859 | |
| 6860 | if (x_non_contig) { |
| 6861 | ctx->prealloc_x_need_sync = true; |
| 6862 | } |
| 6863 | if (y_non_contig || quantize_y) { |
| 6864 | ctx->prealloc_y_need_sync = true; |
| 6865 | } |
| 6866 | } |
| 6867 | |
| 6868 | static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { |
| 6869 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 6870 | const ggml_tensor * src0 = dst->src[0]; |
| 6871 | const ggml_tensor * src1 = dst->src[1]; |
| 6872 | VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; |
| 6873 | std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; |
| 6874 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; |
| 6875 | std::cerr << "))" ); |
| 6876 | GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); |
| 6877 | GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT |
| 6878 | GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT |
| 6879 | GGML_ASSERT(src0->type == GGML_TYPE_F16); |
| 6880 | GGML_ASSERT(src1->type == GGML_TYPE_F32); |
| 6881 | |
| 6882 | const uint64_t ne00 = src0->ne[0]; |
| 6883 | const uint64_t ne01 = src0->ne[1]; |
| 6884 | const uint64_t ne02 = src0->ne[2]; |
| 6885 | // const uint64_t ne03 = src0->ne[3]; |
| 6886 | |
| 6887 | const uint64_t ne10 = src1->ne[0]; |
| 6888 | const uint64_t ne11 = src1->ne[1]; |
| 6889 | const uint64_t ne12 = src1->ne[2]; |
| 6890 | // const uint64_t ne13 = src1->ne[3]; |
| 6891 | |
| 6892 | GGML_ASSERT(ne11 == 1); |
| 6893 | |
| 6894 | ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; |
| 6895 | ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; |
| 6896 | |
| 6897 | vk_buffer d_Qy = nullptr; |
| 6898 | size_t qy_buf_offset = 0; |
| 6899 | |
| 6900 | bool src1_uma = false; |
| 6901 | |
| 6902 | if (ctx->device->uma) { |
| 6903 | ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
| 6904 | src1_uma = d_Qy != nullptr; |
| 6905 | } |
| 6906 | |
| 6907 | const uint64_t x_ne = ne00 * ne01 * ne02; |
| 6908 | const uint64_t y_ne = ne10 * ne11 * ne12; |
| 6909 | const uint64_t d_ne = ne01 * ne11 * ne12; |
| 6910 | |
| 6911 | const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
| 6912 | const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
| 6913 | const uint64_t d_sz = sizeof(float) * d_ne; |
| 6914 | |
| 6915 | // With grouped query attention there are > 1 Q matrices per K, V matrix. |
| 6916 | uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02; |
| 6917 | if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { |
| 6918 | gqa_ratio = 1; |
| 6919 | } |
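|      | // Only ratios 1..8 have dedicated pipeline variants (the array below is indexed by
|      | // gqa_ratio - 1); anything else falls back to the ratio-1 kernel.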
| 6920 | |
| 6921 | { |
| 6922 | // Request descriptor sets |
| 6923 | ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
| 6924 | } |
| 6925 | |
| 6926 | vk_buffer d_D; |
| 6927 | uint64_t d_buf_offset = 0; |
| 6928 | |
| 6929 | if (ctx->num_additional_fused_ops > 0) { |
| 6930 | const ggml_tensor * add = cgraph->nodes[node_idx + 1]; |
| 6931 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context; |
| 6932 | d_D = dst_buf_ctx->dev_buffer; |
| 6933 | d_buf_offset = vk_tensor_offset(add) + add->view_offs;
| 6934 | } else {
| 6935 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
| 6936 | d_D = dst_buf_ctx->dev_buffer;
| 6937 | d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
| 6938 | } |
| 6939 | GGML_ASSERT(d_D != nullptr); |
| 6940 | vk_buffer d_Qx = src0_buf_ctx->dev_buffer; |
| 6941 | const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
| 6942 | GGML_ASSERT(d_Qx != nullptr); |
| 6943 | if (!src1_uma) { |
| 6944 | d_Qy = src1_buf_ctx->dev_buffer; |
| 6945 | qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
| 6946 | GGML_ASSERT(d_Qy != nullptr);
| 6947 | } |
| 6948 | |
| 6949 | const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; |
| 6950 | const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; |
| 6951 | |
| 6952 | const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; |
| 6953 | const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; |
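|      | // Descriptor offsets must respect minStorageBufferOffsetAlignment, so each tensor
|      | // offset is split into an aligned part bound in the descriptor and a small remainder
|      | // (qy_shader_offset / d_shader_offset) handed to the shader via push constants.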
| 6954 | |
| 6955 | uint32_t enable_bias = ctx->num_additional_fused_ops > 0; |
| 6956 | |
| 6957 | vk_buffer d_B = d_D; |
| 6958 | size_t b_buf_offset = 0; |
| 6959 | uint64_t b_sz = 0; |
| 6960 | |
| 6961 | if (enable_bias) { |
| 6962 | const ggml_tensor * add = cgraph->nodes[node_idx + 1]; |
| 6963 | const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0]; |
| 6964 | |
| 6965 | bool b_uma = false; |
| 6966 | if (ctx->device->uma) { |
| 6967 | ggml_vk_host_get(ctx->device, bias->data, d_B, b_buf_offset);
| 6968 | b_uma = d_B != nullptr; |
| 6969 | } |
| 6970 | if(!b_uma) { |
| 6971 | ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context; |
| 6972 | d_B = bias_buf_ctx->dev_buffer; |
| 6973 | b_buf_offset = vk_tensor_offset(bias) + bias->view_offs;
| 6974 | GGML_ASSERT(d_B != nullptr);
| 6975 | b_sz = ggml_nbytes(bias);
| 6976 | } |
| 6977 | } |
| 6978 | |
| 6979 | // compute |
| 6980 | const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), enable_bias };
| 6981 | |
| 6982 | uint32_t workgroups_z = (uint32_t)ne12; |
| 6983 | // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups |
| 6984 | if (gqa_ratio > 1) { |
| 6985 | workgroups_z /= gqa_ratio; |
| 6986 | } |
| 6987 | |
| 6988 | ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1],
| 6989 | {
| 6990 | vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz },
| 6991 | vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset },
| 6992 | vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset },
| 6993 | vk_subbuffer{ d_B, b_buf_offset, b_sz },
| 6994 | }, pc, { 1, (uint32_t)ne01, workgroups_z });
| 6995 | } |
| 6996 | |
| 6997 | static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { |
| 6998 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 6999 | const ggml_tensor * src0 = dst->src[0]; |
| 7000 | const ggml_tensor * src1 = dst->src[1]; |
| 7001 | VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; |
| 7002 | std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; |
| 7003 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; |
| 7004 | std::cerr << "))" ); |
| 7005 | GGML_ASSERT(!ggml_is_transposed(src0)); |
| 7006 | GGML_ASSERT(!ggml_is_transposed(src1)); |
| 7007 | GGML_ASSERT(!ggml_is_permuted(src0)); |
| 7008 | GGML_ASSERT(src0->type == GGML_TYPE_F16); |
| 7009 | GGML_ASSERT(src1->type == GGML_TYPE_F32); |
| 7010 | |
| 7011 | const uint64_t ne00 = src0->ne[0]; |
| 7012 | const uint64_t ne01 = src0->ne[1]; |
| 7013 | const uint64_t ne02 = src0->ne[2]; |
| 7014 | const uint64_t ne03 = src0->ne[3]; |
| 7015 | |
| 7016 | const uint64_t nb01 = src0->nb[1]; |
| 7017 | const uint64_t nb02 = src0->nb[2]; |
| 7018 | |
| 7019 | const uint64_t nb12 = src1->nb[2]; |
| 7020 | |
| 7021 | // const uint64_t ne10 = src1->ne[0]; |
| 7022 | const uint64_t ne11 = src1->ne[1]; |
| 7023 | const uint64_t ne12 = src1->ne[2]; |
| 7024 | // const uint64_t ne13 = src1->ne[3]; |
| 7025 | |
| 7026 | const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t)); |
| 7027 | const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float)); |
| 7028 | const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float)); |
| 7029 | |
| 7030 | GGML_ASSERT(ne11 == 1); |
| 7031 | GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op |
| 7032 | |
| 7033 | ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; |
| 7034 | ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; |
| 7035 | |
| 7036 | vk_buffer d_Qy = nullptr; |
| 7037 | size_t qy_buf_offset = 0; |
| 7038 | |
| 7039 | bool src1_uma = false; |
| 7040 | |
| 7041 | if (ctx->device->uma) { |
| 7042 | ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
| 7043 | src1_uma = d_Qy != nullptr; |
| 7044 | } |
| 7045 | |
| 7046 | const uint64_t d_ne = ne01 * ne11 * ne12 * ne03; |
| 7047 | |
| 7048 | const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); |
| 7049 | const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); |
| 7050 | const uint32_t channel_stride_y = nb12 / sizeof(float); |
| 7051 | |
| 7052 | const uint64_t qx_sz = ggml_nbytes(src0);
| 7053 | const uint64_t qy_sz = ggml_nbytes(src1);
| 7054 | const uint64_t d_sz = sizeof(float) * d_ne; |
| 7055 | |
| 7056 | { |
| 7057 | // Request descriptor sets |
| 7058 | ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
| 7059 | } |
| 7060 | |
| 7061 | vk_buffer d_D; |
| 7062 | uint64_t d_buf_offset = 0; |
| 7063 | |
| 7064 | if (ctx->num_additional_fused_ops > 0) { |
| 7065 | const ggml_tensor * add = cgraph->nodes[node_idx + 1]; |
| 7066 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context; |
| 7067 | d_D = dst_buf_ctx->dev_buffer; |
| 7068 | d_buf_offset = vk_tensor_offset(add) + add->view_offs;
| 7069 | } else {
| 7070 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
| 7071 | d_D = dst_buf_ctx->dev_buffer;
| 7072 | d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
| 7073 | } |
| 7074 | |
| 7075 | GGML_ASSERT(d_D != nullptr); |
| 7076 | vk_buffer d_Qx = src0_buf_ctx->dev_buffer; |
| 7077 | const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
| 7078 | GGML_ASSERT(d_Qx != nullptr); |
| 7079 | if (!src1_uma) { |
| 7080 | d_Qy = src1_buf_ctx->dev_buffer; |
| 7081 | qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
| 7082 | GGML_ASSERT(d_Qy != nullptr);
| 7083 | } |
| 7084 | |
| 7085 | const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; |
| 7086 | const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; |
| 7087 | |
| 7088 | const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; |
| 7089 | const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; |
| 7090 | |
| 7091 | uint32_t enable_bias = ctx->num_additional_fused_ops > 0; |
| 7092 | |
| 7093 | vk_buffer d_B = d_D; |
| 7094 | size_t b_buf_offset = 0; |
| 7095 | uint64_t b_sz = 0; |
| 7096 | |
| 7097 | if (enable_bias) { |
| 7098 | const ggml_tensor * add = cgraph->nodes[node_idx + 1]; |
| 7099 | const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0]; |
| 7100 | |
| 7101 | bool b_uma = false; |
| 7102 | if (ctx->device->uma) { |
| 7103 | ggml_vk_host_get(ctx->device, bias->data, d_B, b_buf_offset);
| 7104 | b_uma = d_B != nullptr; |
| 7105 | } |
| 7106 | if(!b_uma) { |
| 7107 | ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context; |
| 7108 | d_B = bias_buf_ctx->dev_buffer; |
| 7109 | b_buf_offset = vk_tensor_offset(bias) + bias->view_offs;
| 7110 | GGML_ASSERT(d_B != nullptr);
| 7111 | b_sz = ggml_nbytes(bias);
| 7112 | } |
| 7113 | } |
| 7114 | |
| 7115 | // compute |
| 7116 | const std::array<uint32_t, 13> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23, enable_bias };
| 7117 | ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
| 7118 | {
| 7119 | vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz },
| 7120 | vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset },
| 7121 | vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset },
| 7122 | vk_subbuffer{ d_B, b_buf_offset, b_sz },
| 7123 | }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
| 7124 | } |
| 7125 | |
| 7126 | static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { |
| 7127 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 7128 | ggml_tensor * src0 = dst->src[0]; |
| 7129 | ggml_tensor * src1 = dst->src[1]; |
| 7130 | VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")" ); |
| 7131 | |
| 7132 | // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases |
| 7133 | // where the M dimension is very large. |
| 7134 | // Split_k doesn't work with M splitting. |
| 7135 | const size_t nbytes = ggml_nbytes(src0);
| 7136 | const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange; |
| 7137 | if (needs_split) { |
| 7138 | // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets) |
| 7139 | const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]); |
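|      | // Hypothetical example: with maxStorageBufferRange = 128 MiB and src0->nb[1] = 64 KiB
|      | // per row, M_split = 128 MiB / (2 * 64 KiB) = 1024 rows per chunk.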
| 7140 | uint32_t m_offset = 0; |
| 7141 | while (m_offset < dst->ne[0]) { |
| 7142 | const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
| 7143 | ggml_tensor dst2 = *dst; |
| 7144 | ggml_tensor src02 = *src0; |
| 7145 | |
| 7146 | dst2.view_src = dst->view_src ? dst->view_src : dst; |
| 7147 | src02.view_src = src0->view_src ? src0->view_src : src0; |
| 7148 | |
| 7149 | dst2.view_offs += m_offset * dst->nb[0]; |
| 7150 | src02.view_offs += m_offset * src0->nb[1]; |
| 7151 | dst2.ne[0] = cur_M_size; |
| 7152 | src02.ne[1] = cur_M_size; |
| 7153 | |
| 7154 | ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true);
| 7155 |
| 7156 | m_offset += cur_M_size;
| 7157 | }
| 7158 | } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
| 7159 | // detect 0213 permutation, and batch size of 1 |
| 7160 | src0->nb[0] <= src0->nb[2] && |
| 7161 | src0->nb[2] <= src0->nb[1] && |
| 7162 | src0->nb[1] <= src0->nb[3] && |
| 7163 | src1->nb[0] <= src1->nb[2] && |
| 7164 | src1->nb[2] <= src1->nb[1] && |
| 7165 | src1->nb[1] <= src1->nb[3] && |
| 7166 | src0->ne[3] == 1 && |
| 7167 | src1->ne[3] == 1) { |
| 7168 | ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx); |
| 7169 | } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
| 7170 | !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
| 7171 | ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx); |
| 7172 | // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) |
| 7173 | // when ne12 and ne13 are one. |
| 7174 | } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
| 7175 | (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
| 7176 | ggml_vk_mul_mat_vec_q_f16(ctx, subctx, cgraph, node_idx);
| 7177 | } else {
| 7178 | ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false);
| 7179 | } |
| 7180 | } |
| 7181 | |
| 7182 | static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { |
| 7183 | VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; |
| 7184 | std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; |
| 7185 | std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; |
| 7186 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" ); |
| 7187 | GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT |
| 7188 | GGML_ASSERT(ids->type == GGML_TYPE_I32); |
| 7189 | |
| 7190 | const uint64_t ne00 = src0->ne[0]; |
| 7191 | const uint64_t ne01 = src0->ne[1]; |
| 7192 | const uint64_t ne02 = src0->ne[2]; |
| 7193 | const uint64_t ne03 = src0->ne[3]; |
| 7194 | |
| 7195 | const uint64_t ne10 = src1->ne[0]; |
| 7196 | const uint64_t ne11 = src1->ne[1]; |
| 7197 | const uint64_t ne12 = src1->ne[2]; |
| 7198 | const uint64_t ne13 = src1->ne[3]; |
| 7199 | |
| 7200 | const uint64_t nei0 = ids->ne[0]; |
| 7201 | const uint64_t nei1 = ids->ne[1]; |
| 7202 | |
| 7203 | const uint32_t nbi1 = ids->nb[1]; |
| 7204 | const uint32_t nbi2 = ids->nb[2]; |
| 7205 | |
| 7206 | const uint64_t ne20 = dst->ne[0]; |
| 7207 | const uint64_t ne21 = dst->ne[1]; |
| 7208 | const uint64_t ne22 = dst->ne[2]; |
| 7209 | const uint64_t ne23 = dst->ne[3]; |
| 7210 | |
| 7211 | const uint64_t n_as = ne02; |
| 7212 | |
| 7213 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; |
| 7214 | ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; |
| 7215 | ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; |
| 7216 | ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; |
| 7217 | |
| 7218 | vk_buffer d_Qx = nullptr; |
| 7219 | size_t qx_buf_offset = 0; |
| 7220 | vk_buffer d_Qy = nullptr; |
| 7221 | size_t qy_buf_offset = 0; |
| 7222 | vk_buffer d_ids = nullptr; |
| 7223 | size_t ids_buf_offset = 0; |
| 7224 | |
| 7225 | bool src0_uma = false; |
| 7226 | bool src1_uma = false; |
| 7227 | bool ids_uma = false; |
| 7228 | |
| 7229 | if (ctx->device->uma) { |
| 7230 | ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
| 7231 | ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
| 7232 | ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
| 7233 | src0_uma = d_Qx != nullptr; |
| 7234 | src1_uma = d_Qy != nullptr; |
| 7235 | ids_uma = d_ids != nullptr; |
| 7236 | } |
| 7237 | |
| 7238 | // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf |
| 7239 | const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || |
| 7240 | !ggml_vk_dim01_contiguous(src0);
| 7241 | const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
| 7242 | (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
| 7243 | !ggml_vk_dim01_contiguous(src1);
| 7244 | |
| 7245 | // If src0 is BF16, try to use a BF16 x BF16 multiply |
| 7246 | ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; |
| 7247 | |
| 7248 | const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; |
| 7249 | |
| 7250 | bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
| 7251 | |
| 7252 | // Check for mmq first |
| 7253 | vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
| 7254 |
| 7255 | if (mmp == nullptr) {
| 7256 | // Fall back to f16 dequant mul mat
| 7257 | mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
| 7258 | quantize_y = false; |
| 7259 | } |
| 7260 | |
| 7261 | const bool qx_needs_dequant = mmp == nullptr || x_non_contig; |
| 7262 | const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); |
| 7263 | |
| 7264 | if (qx_needs_dequant) { |
| 7265 | // Fall back to dequant + f16 mulmat |
| 7266 | mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); |
| 7267 | } |
| 7268 | |
| 7269 | // Not implemented |
| 7270 | GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT |
| 7271 | |
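| | // Use the aligned pipeline variant only when K already matches the pipeline alignment and the matrices are large enough to benefit. |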
| 7272 | const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); |
| 7273 | const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; |
| 7274 | |
| 7275 | vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); |
| 7276 | |
| 7277 | // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking |
| 7278 | uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; |
| 7279 | const uint64_t x_ne = ne01 * ne00; |
| 7280 | const uint64_t y_ne = padded_n * ne10; |
| 7281 | const uint64_t d_ne = ne21 * ne20; |
| 7282 | |
| 7283 | const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); |
| 7284 | const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); |
| 7285 | const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; |
| 7286 | const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); |
| 7287 | const uint64_t ids_sz = nbi2; |
| 7288 | const uint64_t d_sz = sizeof(float) * d_ne; |
| 7289 | |
| 7290 | vk_pipeline to_fp16_vk_0 = nullptr; |
| 7291 | vk_pipeline to_fp16_vk_1 = nullptr; |
| 7292 | vk_pipeline to_q8_1 = nullptr; |
| 7293 | |
| 7294 | if (x_non_contig) { |
| 7295 | to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); |
| 7296 | } else { |
| 7297 | to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); |
| 7298 | } |
| 7299 | if (y_non_contig) { |
| 7300 | to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); |
| 7301 | } else { |
| 7302 | to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); |
| 7303 | } |
| 7304 | GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT |
| 7305 | GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT |
| 7306 | |
| 7307 | if (quantize_y) { |
| 7308 | to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); |
| 7309 | } |
| 7310 | |
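| | // Grow the preallocated dequant/quant staging buffers if needed and reserve descriptor sets for every pipeline dispatched below. |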
| 7311 | { |
| 7312 | const uint64_t x_sz_upd = x_sz * ne02 * ne03; |
| 7313 | uint64_t y_sz_upd = y_sz * ne12 * ne13; |
| 7314 | if (quantize_y) { |
| 7315 | y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; |
| 7316 | } |
| 7317 | if ( |
| 7318 | (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || |
| 7319 | (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { |
| 7320 | GGML_ABORT("Requested preallocation size is too large"); |
| 7321 | } |
| 7322 | if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { |
| 7323 | ctx->prealloc_size_x = x_sz_upd; |
| 7324 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 7325 | } |
| 7326 | if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { |
| 7327 | ctx->prealloc_size_y = y_sz_upd; |
| 7328 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 7329 | } |
| 7330 | |
| 7331 | // Request descriptor sets |
| 7332 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); |
| 7333 | if (qx_needs_dequant) { |
| 7334 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); |
| 7335 | } |
| 7336 | if (qy_needs_dequant) { |
| 7337 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); |
| 7338 | } |
| 7339 | if (quantize_y) { |
| 7340 | ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); |
| 7341 | } |
| 7342 | } |
| 7343 | |
| 7344 | vk_buffer d_D = dst_buf_ctx->dev_buffer; |
| 7345 | const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; |
| 7346 | GGML_ASSERT(d_D != nullptr); |
| 7347 | vk_buffer d_X; |
| 7348 | uint64_t x_buf_offset = 0; |
| 7349 | vk_buffer d_Y; |
| 7350 | uint64_t y_buf_offset = 0; |
| 7351 | if (!src0_uma) { |
| 7352 | d_Qx = src0_buf_ctx->dev_buffer; |
| 7353 | qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; |
| 7354 | GGML_ASSERT(d_Qx != nullptr); |
| 7355 | } |
| 7356 | if (!src1_uma) { |
| 7357 | d_Qy = src1_buf_ctx->dev_buffer; |
| 7358 | qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; |
| 7359 | GGML_ASSERT(d_Qy != nullptr); |
| 7360 | } |
| 7361 | if (!ids_uma) { |
| 7362 | d_ids = ids_buf_ctx->dev_buffer; |
| 7363 | ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; |
| 7364 | GGML_ASSERT(d_ids != nullptr); |
| 7365 | } |
| 7366 | if (qx_needs_dequant) { |
| 7367 | d_X = ctx->prealloc_x; |
| 7368 | GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); |
| 7369 | } else { |
| 7370 | d_X = d_Qx; |
| 7371 | x_buf_offset = qx_buf_offset; |
| 7372 | GGML_ASSERT(qx_sz == x_sz); |
| 7373 | } |
| 7374 | if (qy_needs_dequant) { |
| 7375 | d_Y = ctx->prealloc_y; |
| 7376 | GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); |
| 7377 | } else if (quantize_y) { |
| 7378 | d_Y = ctx->prealloc_y; |
| 7379 | GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); |
| 7380 | } else { |
| 7381 | d_Y = d_Qy; |
| 7382 | y_buf_offset = qy_buf_offset; |
| 7383 | GGML_ASSERT(qy_sz == y_sz); |
| 7384 | } |
| 7385 | |
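| | // A previous node may still be using the shared preallocated X buffer; synchronize before overwriting it. |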
| 7386 | if (x_non_contig || qx_needs_dequant) { |
| 7387 | if (ctx->prealloc_x_need_sync) { |
| 7388 | ggml_vk_sync_buffers(ctx, subctx); |
| 7389 | } |
| 7390 | } |
| 7391 | |
| 7392 | if (x_non_contig) { |
| 7393 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); |
| 7394 | } else if (qx_needs_dequant) { |
| 7395 | const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; |
| 7396 | ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, |
| 7397 | { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); |
| 7398 | ggml_vk_sync_buffers(ctx, subctx); |
| 7399 | } |
| 7400 | if (y_non_contig) { |
| 7401 | if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || |
| 7402 | ctx->prealloc_y_last_tensor_used != src1) { |
| 7403 | if (ctx->prealloc_y_need_sync) { |
| 7404 | ggml_vk_sync_buffers(ctx, subctx); |
| 7405 | } |
| 7406 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); |
| 7407 | ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); |
| 7408 | ctx->prealloc_y_last_tensor_used = src1; |
| 7409 | } |
| 7410 | } |
| 7411 | if (quantize_y) { |
| 7412 | if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || |
| 7413 | ctx->prealloc_y_last_tensor_used != src1) { |
| 7414 | if (ctx->prealloc_y_need_sync) { |
| 7415 | ggml_vk_sync_buffers(ctx, subctx); |
| 7416 | } |
| 7417 | ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); |
| 7418 | ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); |
| 7419 | ctx->prealloc_y_last_tensor_used = src1; |
| 7420 | } |
| 7421 | } |
| 7422 | |
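| | // Batch strides are in elements; a non-dim01-contiguous tensor that is used directly takes its stride from its byte strides instead. |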
| 7423 | uint32_t stride_batch_x = ne00*ne01; |
| 7424 | uint32_t stride_batch_y = ne10*ne11; |
| 7425 | |
| 7426 | if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { |
| 7427 | stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); |
| 7428 | } |
| 7429 | |
| 7430 | if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { |
| 7431 | stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); |
| 7432 | } |
| 7433 | |
| 7434 | uint32_t y_sz_total = y_sz * ne12 * ne13; |
| 7435 | if (quantize_y) { |
| 7436 | y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; |
| 7437 | } |
| 7438 | |
| 7439 | // compute |
| 7440 | ggml_vk_matmul_id( |
| 7441 | ctx, subctx, pipeline, |
| 7442 | { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, |
| 7443 | { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, |
| 7444 | ne01, ne21, ne10, ne10, ne10, ne01, |
| 7445 | stride_batch_x, stride_batch_y, ne20*ne21, |
| 7446 | n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n |
| 7447 | ); // NOLINT |
| 7448 | |
| 7449 | if (x_non_contig || qx_needs_dequant) { |
| 7450 | ctx->prealloc_x_need_sync = true; |
| 7451 | } |
| 7452 | if (y_non_contig) { |
| 7453 | ctx->prealloc_y_need_sync = true; |
| 7454 | } |
| 7455 | } |
| 7456 | |
| 7457 | static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { |
| 7458 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 7459 | ggml_tensor * src0 = dst->src[0]; |
| 7460 | ggml_tensor * src1 = dst->src[1]; |
| 7461 | ggml_tensor * ids = dst->src[2]; |
| 7462 | VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; |
| 7463 | std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; |
| 7464 | std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; |
| 7465 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; |
| 7466 | std::cerr << "))"); |
| 7467 | GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT |
| 7468 | GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT |
| 7469 | GGML_ASSERT(ids->type == GGML_TYPE_I32); |
| 7470 | |
| 7471 | const uint64_t ne00 = src0->ne[0]; |
| 7472 | const uint64_t ne01 = src0->ne[1]; |
| 7473 | const uint64_t ne02 = src0->ne[2]; |
| 7474 | const uint64_t ne03 = src0->ne[3]; |
| 7475 | |
| 7476 | const uint64_t ne10 = src1->ne[0]; |
| 7477 | const uint64_t ne11 = src1->ne[1]; |
| 7478 | const uint64_t ne12 = src1->ne[2]; |
| 7479 | const uint64_t ne13 = src1->ne[3]; |
| 7480 | |
| 7481 | const uint64_t nei0 = ids->ne[0]; |
| 7482 | const uint64_t nei1 = ids->ne[1]; |
| 7483 | |
| 7484 | const uint64_t nbi2 = ids->nb[2]; |
| 7485 | |
| 7486 | GGML_ASSERT(nei1 == 1); |
| 7487 | |
| 7488 | const uint64_t ne20 = dst->ne[0]; |
| 7489 | const uint64_t ne21 = dst->ne[1]; |
| 7490 | const uint64_t ne22 = dst->ne[2]; |
| 7491 | const uint64_t ne23 = dst->ne[3]; |
| 7492 | |
| 7493 | ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; |
| 7494 | ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; |
| 7495 | ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; |
| 7496 | |
| 7497 | vk_buffer d_Qx = nullptr; |
| 7498 | size_t qx_buf_offset = 0; |
| 7499 | vk_buffer d_Qy = nullptr; |
| 7500 | size_t qy_buf_offset = 0; |
| 7501 | vk_buffer d_ids = nullptr; |
| 7502 | size_t ids_buf_offset = 0; |
| 7503 | |
| 7504 | bool src0_uma = false; |
| 7505 | bool src1_uma = false; |
| 7506 | bool ids_uma = false; |
| 7507 | |
| 7508 | if (ctx->device->uma) { |
| 7509 | ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); |
| 7510 | ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); |
| 7511 | ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); |
| 7512 | src0_uma = d_Qx != nullptr; |
| 7513 | src1_uma = d_Qy != nullptr; |
| 7514 | ids_uma = d_ids != nullptr; |
| 7515 | } |
| 7516 | |
| 7517 | const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); |
| 7518 | const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); |
| 7519 | |
| 7520 | const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; |
| 7521 | |
| 7522 | const bool qx_needs_dequant = x_non_contig; |
| 7523 | const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; |
| 7524 | |
| 7525 | // Not implemented |
| 7526 | GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT |
| 7527 | |
| 7528 | const uint64_t x_ne = ne01 * ne00; |
| 7529 | const uint64_t y_ne = ne11 * ne10; |
| 7530 | const uint64_t d_ne = ne21 * ne20; |
| 7531 | |
| 7532 | const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); |
| 7533 | const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); |
| 7534 | const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; |
| 7535 | const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; |
| 7536 | const uint64_t ids_sz = nbi2; |
| 7537 | const uint64_t d_sz = sizeof(float) * d_ne; |
| 7538 | |
| 7539 | vk_pipeline to_fp16_vk_0 = nullptr; |
| 7540 | vk_pipeline to_fp16_vk_1 = nullptr; |
| 7541 | if (x_non_contig) { |
| 7542 | to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); |
| 7543 | } |
| 7544 | if (y_non_contig) { |
| 7545 | to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); |
| 7546 | } else { |
| 7547 | to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); |
| 7548 | } |
| 7549 | vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); |
| 7550 | GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT |
| 7551 | GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT |
| 7552 | GGML_ASSERT(dmmv != nullptr); |
| 7553 | |
| 7554 | { |
| 7555 | const uint64_t x_sz_upd = x_sz * ne02 * ne03; |
| 7556 | const uint64_t y_sz_upd = y_sz * ne12 * ne13; |
| 7557 | if ( |
| 7558 | (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || |
| 7559 | (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { |
| 7560 | GGML_ABORT("Requested preallocation size is too large"); |
| 7561 | } |
| 7562 | if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { |
| 7563 | ctx->prealloc_size_x = x_sz_upd; |
| 7564 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 7565 | } |
| 7566 | if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { |
| 7567 | ctx->prealloc_size_y = y_sz_upd; |
| 7568 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 7569 | } |
| 7570 | |
| 7571 | // Request descriptor sets |
| 7572 | if (qx_needs_dequant) { |
| 7573 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); |
| 7574 | } |
| 7575 | if (qy_needs_dequant) { |
| 7576 | ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); |
| 7577 | } |
| 7578 | ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); |
| 7579 | } |
| 7580 | |
| 7581 | vk_buffer d_D; |
| 7582 | uint64_t d_buf_offset = 0; |
| 7583 | |
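| | // When the following ADD is fused into this node, write the result directly into that node's output buffer. |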
| 7584 | if (ctx->num_additional_fused_ops > 0) { |
| 7585 | const ggml_tensor * add = cgraph->nodes[node_idx + 1]; |
| 7586 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context; |
| 7587 | d_D = dst_buf_ctx->dev_buffer; |
| 7588 | d_buf_offset = vk_tensor_offset(add) + add->view_offs; |
| 7589 | } else { |
| 7590 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; |
| 7591 | d_D = dst_buf_ctx->dev_buffer; |
| 7592 | d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; |
| 7593 | } |
| 7594 | |
| 7595 | GGML_ASSERT(d_D != nullptr); |
| 7596 | vk_buffer d_X; |
| 7597 | uint64_t x_buf_offset = 0; |
| 7598 | vk_buffer d_Y; |
| 7599 | uint64_t y_buf_offset = 0; |
| 7600 | if(!src0_uma) { |
| 7601 | d_Qx = src0_buf_ctx->dev_buffer; |
| 7602 | qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; |
| 7603 | GGML_ASSERT(d_Qx != nullptr); |
| 7604 | } |
| 7605 | if(!src1_uma) { |
| 7606 | d_Qy = src1_buf_ctx->dev_buffer; |
| 7607 | qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; |
| 7608 | GGML_ASSERT(d_Qy != nullptr); |
| 7609 | } |
| 7610 | if(!ids_uma) { |
| 7611 | d_ids = ids_buf_ctx->dev_buffer; |
| 7612 | ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; |
| 7613 | GGML_ASSERT(d_ids != nullptr); |
| 7614 | } |
| 7615 | if (qx_needs_dequant) { |
| 7616 | d_X = ctx->prealloc_x; |
| 7617 | } else { |
| 7618 | d_X = d_Qx; |
| 7619 | x_buf_offset = qx_buf_offset; |
| 7620 | GGML_ASSERT(qx_sz == x_sz); |
| 7621 | } |
| 7622 | if (qy_needs_dequant) { |
| 7623 | d_Y = ctx->prealloc_y; |
| 7624 | } else { |
| 7625 | d_Y = d_Qy; |
| 7626 | y_buf_offset = qy_buf_offset; |
| 7627 | GGML_ASSERT(qy_sz == y_sz); |
| 7628 | } |
| 7629 | |
| 7630 | if (x_non_contig) { |
| 7631 | if (ctx->prealloc_x_need_sync) { |
| 7632 | ggml_vk_sync_buffers(ctx, subctx); |
| 7633 | } |
| 7634 | } |
| 7635 | |
| 7636 | if (x_non_contig) { |
| 7637 | GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); |
| 7638 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); |
| 7639 | } |
| 7640 | if (y_non_contig) { |
| 7641 | GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); |
| 7642 | if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || |
| 7643 | ctx->prealloc_y_last_tensor_used != src1) { |
| 7644 | if (ctx->prealloc_y_need_sync) { |
| 7645 | ggml_vk_sync_buffers(ctx, subctx); |
| 7646 | } |
| 7647 | ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); |
| 7648 | ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); |
| 7649 | ctx->prealloc_y_last_tensor_used = src1; |
| 7650 | } |
| 7651 | } |
| 7652 | |
| 7653 | uint32_t stride_batch_y = ne10*ne11; |
| 7654 | |
| 7655 | if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { |
| 7656 | stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); |
| 7657 | } |
| 7658 | |
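| | // If the row count exceeds the device's maximum workgroup count in X, spread the work across the Z dimension. |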
| 7659 | const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; |
| 7660 | |
| 7661 | uint32_t groups_x = ne01; |
| 7662 | uint32_t groups_z = 1; |
| 7663 | |
| 7664 | if (ne01 > max_groups_x) { |
| 7665 | groups_z = 64; |
| 7666 | groups_x = CEIL_DIV(groups_x, groups_z); |
| 7667 | } |
| 7668 | |
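| | // With a fused ADD, its second operand is passed to the shader as a bias buffer. |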
| 7669 | uint32_t enable_bias = ctx->num_additional_fused_ops > 0; |
| 7670 | |
| 7671 | vk_buffer d_B = d_D; |
| 7672 | size_t b_buf_offset = 0; |
| 7673 | uint64_t b_sz = 0; |
| 7674 | |
| 7675 | if (enable_bias) { |
| 7676 | const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1]; |
| 7677 | |
| 7678 | bool b_uma = false; |
| 7679 | if (ctx->device->uma) { |
| 7680 | ggml_vk_host_get(ctx->device, bias->data, d_B, b_buf_offset); |
| 7681 | b_uma = d_B != nullptr; |
| 7682 | } |
| 7683 | if(!b_uma) { |
| 7684 | ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context; |
| 7685 | d_B = bias_buf_ctx->dev_buffer; |
| 7686 | b_buf_offset = vk_tensor_offset(bias) + bias->view_offs; |
| 7687 | GGML_ASSERT(d_B != nullptr); |
| 7688 | b_sz = ggml_nbytes(bias); |
| 7689 | } |
| 7690 | } |
| 7691 | |
| 7692 | // compute |
| 7693 | const vk_mat_vec_id_push_constants pc = { |
| 7694 | (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, |
| 7695 | (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), |
| 7696 | |
| 7697 | enable_bias, |
| 7698 | |
| 7699 | (uint32_t)nei0, (uint32_t)ne11, |
| 7700 | }; |
| 7701 | ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, |
| 7702 | { |
| 7703 | vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, |
| 7704 | vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, |
| 7705 | vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, |
| 7706 | vk_subbuffer{ d_B, b_buf_offset, b_sz }, |
| 7707 | vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, |
| 7708 | }, |
| 7709 | pc, { groups_x, (uint32_t)nei0, groups_z }); |
| 7710 | |
| 7711 | if (x_non_contig) { |
| 7712 | ctx->prealloc_x_need_sync = true; |
| 7713 | } |
| 7714 | if (y_non_contig) { |
| 7715 | ctx->prealloc_y_need_sync = true; |
| 7716 | } |
| 7717 | } |
| 7718 | |
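| | // Use the matrix-vector path when the ids tensor holds a single row of expert selections and src0 has a supported type. |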
| 7719 | static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int node_idx) { |
| 7720 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 7721 | ggml_tensor * src0 = dst->src[0]; |
| 7722 | ggml_tensor * src2 = dst->src[2]; |
| 7723 | return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); |
| 7724 | } |
| 7725 | |
| 7726 | static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { |
| 7727 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 7728 | ggml_tensor * src0 = dst->src[0]; |
| 7729 | ggml_tensor * src1 = dst->src[1]; |
| 7730 | ggml_tensor * src2 = dst->src[2]; |
| 7731 | VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); |
| 7732 | if (ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) { |
| 7733 | ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, cgraph, node_idx); |
| 7734 | } else { |
| 7735 | ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst); |
| 7736 | } |
| 7737 | } |
| 7738 | |
| 7739 | static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { |
| 7740 | // Needs to be kept up to date on shader changes |
| 7741 | GGML_UNUSED(hsv); |
| 7742 | const uint32_t wg_size = scalar_flash_attention_workgroup_size; |
| 7743 | const uint32_t Br = get_fa_scalar_num_large_rows(hsv); |
| 7744 | const uint32_t Bc = scalar_flash_attention_Bc; |
| 7745 | |
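| | // Shared memory per workgroup: reduction scratch (tmpsh, tmpshv4), the mask tile and the Q tile. |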
| 7746 | const uint32_t tmpsh = wg_size * sizeof(float); |
| 7747 | const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); |
| 7748 | |
| 7749 | const uint32_t masksh = Bc * Br * sizeof(float); |
| 7750 | |
| 7751 | const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); |
| 7752 | |
| 7753 | const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; |
| 7754 | const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; |
| 7755 | |
| 7756 | VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); |
| 7757 | |
| 7758 | return supported; |
| 7759 | } |
| 7760 | |
| 7761 | static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { |
| 7762 | // Needs to be kept up to date on shader changes |
| 7763 | GGML_UNUSED(hsv); |
| 7764 | const uint32_t wg_size = scalar_flash_attention_workgroup_size; |
| 7765 | const uint32_t Br = coopmat1_flash_attention_num_large_rows; |
| 7766 | const uint32_t Bc = scalar_flash_attention_Bc; |
| 7767 | |
| 7768 | const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); |
| 7769 | |
| 7770 | const uint32_t acctype = f32acc ? 4 : 2; |
| 7771 | const uint32_t f16vec4 = 8; |
| 7772 | |
| 7773 | const uint32_t tmpsh = wg_size * sizeof(float); |
| 7774 | const uint32_t tmpshv4 = wg_size * 4 * acctype; |
| 7775 | |
| 7776 | const uint32_t qstride = hsk_pad / 4 + 2; |
| 7777 | const uint32_t Qf = Br * qstride * f16vec4; |
| 7778 | |
| 7779 | const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; |
| 7780 | const uint32_t sfsh = Bc * sfshstride * acctype; |
| 7781 | |
| 7782 | const uint32_t kshstride = hsk_pad / 4 + 2; |
| 7783 | const uint32_t ksh = Bc * kshstride * f16vec4; |
| 7784 | |
| 7785 | const uint32_t slope = Br * sizeof(float); |
| 7786 | |
| 7787 | const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; |
| 7788 | const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; |
| 7789 | |
| 7790 | VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); |
| 7791 | |
| 7792 | return supported; |
| 7793 | } |
| 7794 | |
| 7795 | static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) { |
| 7796 | VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; |
| 7797 | std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; |
| 7798 | std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; |
| 7799 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; |
| 7800 | if (sinks) { |
| 7801 | std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3]; |
| 7802 | } |
| 7803 | std::cerr << "))"); |
| 7804 | |
| 7805 | GGML_TENSOR_LOCALS(int64_t, neq, q, ne) |
| 7806 | GGML_TENSOR_LOCALS(size_t, nbq, q, nb) |
| 7807 | GGML_TENSOR_LOCALS(int64_t, nek, k, ne) |
| 7808 | GGML_TENSOR_LOCALS(size_t, nbk, k, nb) |
| 7809 | GGML_TENSOR_LOCALS(int64_t, nev, v, ne) |
| 7810 | GGML_TENSOR_LOCALS(size_t, nbv, v, nb) |
| 7811 | GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) |
| 7812 | GGML_TENSOR_LOCALS(size_t, nb, dst, nb) |
| 7813 | |
| 7814 | const uint32_t nem1 = mask ? mask->ne[1] : 0; |
| 7815 | const uint32_t nem2 = mask ? mask->ne[2] : 0; |
| 7816 | const uint32_t nem3 = mask ? mask->ne[3] : 0; |
| 7817 | |
| 7818 | const uint32_t HSK = nek0; |
| 7819 | const uint32_t HSV = nev0; |
| 7820 | uint32_t N = neq1; |
| 7821 | const uint32_t KV = nek1; |
| 7822 | |
| 7823 | GGML_ASSERT(ne0 == HSV); |
| 7824 | GGML_ASSERT(ne2 == N); |
| 7825 | |
| 7826 | // input tensor rows must be contiguous |
| 7827 | GGML_ASSERT(nbq0 == ggml_type_size(q->type)); |
| 7828 | GGML_ASSERT(nbk0 == ggml_type_size(k->type)); |
| 7829 | GGML_ASSERT(nbv0 == ggml_type_size(v->type)); |
| 7830 | |
| 7831 | GGML_ASSERT(neq0 == HSK); |
| 7832 | |
| 7833 | GGML_ASSERT(neq1 == N); |
| 7834 | |
| 7835 | GGML_ASSERT(nev1 == nek1); |
| 7836 | |
| 7837 | // dst cannot be transposed or permuted |
| 7838 | GGML_ASSERT(nb0 == sizeof(float)); |
| 7839 | GGML_ASSERT(nb0 <= nb1); |
| 7840 | GGML_ASSERT(nb1 <= nb2); |
| 7841 | GGML_ASSERT(nb2 <= nb3); |
| 7842 | |
| 7843 | assert(dst->type == GGML_TYPE_F32); |
| 7844 | assert(q->type == GGML_TYPE_F32); |
| 7845 | assert(k->type == v->type); |
| 7846 | |
| 7847 | FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : |
| 7848 | ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; |
| 7849 | |
| 7850 | if (path == FA_COOPMAT1) { |
| 7851 | const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || |
| 7852 | (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); |
| 7853 | |
| 7854 | const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); |
| 7855 | |
| 7856 | if (!coopmat_shape_supported || !coopmat_shmem_supported) { |
| 7857 | path = FA_SCALAR; |
| 7858 | } |
| 7859 | } |
| 7860 | |
| 7861 | uint32_t gqa_ratio = 1; |
| 7862 | uint32_t qk_ratio = neq2 / nek2; |
| 7863 | uint32_t workgroups_x = (uint32_t)neq1; |
| 7864 | uint32_t workgroups_y = (uint32_t)neq2; |
| 7865 | uint32_t workgroups_z = (uint32_t)neq3; |
| 7866 | |
| 7867 | // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa. |
| 7868 | // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). |
| 7869 | uint32_t max_gqa; |
| 7870 | switch (path) { |
| 7871 | case FA_SCALAR: |
| 7872 | case FA_COOPMAT1: |
| 7873 | // We may switch from coopmat1 to scalar, so use the scalar limit for both |
| 7874 | max_gqa = get_fa_scalar_num_large_rows(HSV); |
| 7875 | break; |
| 7876 | case FA_COOPMAT2: |
| 7877 | max_gqa = get_fa_num_small_rows(FA_COOPMAT2); |
| 7878 | break; |
| 7879 | default: |
| 7880 | GGML_ASSERT(0); |
| 7881 | } |
| 7882 | |
| 7883 | if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && |
| 7884 | qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { |
| 7885 | // grouped query attention - make the N dimension equal to gqa_ratio, reduce |
| 7886 | // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 |
| 7887 | // and change addressing calculations to index Q's dimension 2. |
| 7888 | gqa_ratio = qk_ratio; |
| 7889 | N = gqa_ratio; |
| 7890 | workgroups_y /= N; |
| 7891 | } |
| 7892 | |
| 7893 | bool small_rows = N <= get_fa_num_small_rows(path); |
| 7894 | |
| 7895 | // coopmat1 does not actually support "small rows" (it needs 16 rows). |
| 7896 | // So use scalar instead. |
| 7897 | if (small_rows && path == FA_COOPMAT1) { |
| 7898 | path = FA_SCALAR; |
| 7899 | } |
| 7900 | |
| 7901 | // scalar is faster than coopmat2 when N==1 |
| 7902 | if (N == 1 && path == FA_COOPMAT2) { |
| 7903 | path = FA_SCALAR; |
| 7904 | } |
| 7905 | |
| 7906 | // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory |
| 7907 | if (path == FA_SCALAR && |
| 7908 | !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { |
| 7909 | small_rows = true; |
| 7910 | } |
| 7911 | |
| 7912 | const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); |
| 7913 | uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); |
| 7914 | uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); |
| 7915 | |
| 7916 | // For F32, the shader treats it as a block of size 4 (for vec4 loads) |
| 7917 | if (k->type == GGML_TYPE_F32) { |
| 7918 | k_stride /= 4; |
| 7919 | } |
| 7920 | if (v->type == GGML_TYPE_F32) { |
| 7921 | v_stride /= 4; |
| 7922 | } |
| 7923 | |
| 7924 | uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); |
| 7925 | bool aligned = (KV % alignment) == 0 && |
| 7926 | // the "aligned" shader variant will forcibly align strides, for performance |
| 7927 | (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; |
| 7928 | |
| 7929 | // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. |
| 7930 | if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { |
| 7931 | aligned = false; |
| 7932 | } |
| 7933 | |
| 7934 | bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; |
| 7935 | |
| 7936 | vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); |
| 7937 | |
| 7938 | vk_pipeline pipeline = nullptr; |
| 7939 | |
| 7940 | { |
| 7941 | std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); |
| 7942 | auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; |
| 7943 | auto it = pipelines.find(fa_pipeline_state); |
| 7944 | if (it != pipelines.end()) { |
| 7945 | pipeline = it->second; |
| 7946 | } else { |
| 7947 | pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>(); |
| 7948 | } |
| 7949 | } |
| 7950 | |
| 7951 | assert(pipeline); |
| 7952 | |
| 7953 | uint32_t split_kv = KV; |
| 7954 | uint32_t split_k = 1; |
| 7955 | |
| 7956 | // Use a placeholder core count if one isn't available. split_k is a big help for perf. |
| 7957 | const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; |
| 7958 | |
| 7959 | // Try to use split_k when KV is large enough to be worth the overhead |
| 7960 | if (workgroups_x == 1 && shader_core_count > 0) { |
| 7961 | // Try to run two workgroups per SM. |
| 7962 | split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); |
| 7963 | if (split_k > 1) { |
| 7964 | // Try to evenly split KV into split_k chunks, but it needs to be a multiple |
| 7965 | // of "align", so recompute split_k based on that. |
| 7966 | split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); |
| 7967 | split_k = CEIL_DIV(KV, split_kv); |
| 7968 | workgroups_x = split_k; |
| 7969 | } |
| 7970 | } |
| 7971 | |
| 7972 | // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) |
| 7973 | // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. |
| 7974 | const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; |
| 7975 | if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { |
| 7976 | GGML_ABORT("Requested preallocation size is too large"); |
| 7977 | } |
| 7978 | if (ctx->prealloc_size_split_k < split_k_size) { |
| 7979 | ctx->prealloc_size_split_k = split_k_size; |
| 7980 | ggml_vk_preallocate_buffers(ctx, subctx); |
| 7981 | } |
| 7982 | |
| 7983 | { |
| 7984 | // Request descriptor sets |
| 7985 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); |
| 7986 | if (split_k > 1) { |
| 7987 | ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); |
| 7988 | } |
| 7989 | } |
| 7990 | |
| 7991 | float scale = 1.0f; |
| 7992 | float max_bias = 0.0f; |
| 7993 | float logit_softcap = 0.0f; |
| 7994 | |
| 7995 | memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); |
| 7996 | memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); |
| 7997 | memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); |
| 7998 | |
| 7999 | if (logit_softcap != 0) { |
| 8000 | scale /= logit_softcap; |
| 8001 | } |
| 8002 | |
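| | // ALiBi slope bases m0 and m1, derived from max_bias and the number of heads. |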
| 8003 | const uint32_t n_head_kv = neq2; |
| 8004 | const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); |
| 8005 | const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); |
| 8006 | const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); |
| 8007 | |
| 8008 | vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q); |
| 8009 | vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k); |
| 8010 | vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v); |
| 8011 | vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); |
| 8012 | vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf; |
| 8013 | vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; |
| 8014 | |
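| | // Pack the sinks/mask presence flags and n_head_log2 into a single push constant word. |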
| 8015 | uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; |
| 8016 | |
| 8017 | const vk_flash_attn_push_constants pc = { N, KV, |
| 8018 | (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, |
| 8019 | (uint32_t)neq2, (uint32_t)neq3, |
| 8020 | (uint32_t)nek2, (uint32_t)nek3, |
| 8021 | (uint32_t)nev2, (uint32_t)nev3, |
| 8022 | nem1, nem2, nem3, |
| 8023 | q_stride, (uint32_t)nbq2, (uint32_t)nbq3, |
| 8024 | k_stride, (uint32_t)nbk2, (uint32_t)nbk3, |
| 8025 | v_stride, (uint32_t)nbv2, (uint32_t)nbv3, |
| 8026 | scale, max_bias, logit_softcap, |
| 8027 | mask_n_head_log2, m0, m1, |
| 8028 | gqa_ratio, split_kv, split_k }; |
| 8029 | |
| 8030 | if (split_k > 1) { |
| 8031 | if (ctx->prealloc_split_k_need_sync) { |
| 8032 | ggml_vk_sync_buffers(ctx, subctx); |
| 8033 | } |
| 8034 | |
| 8035 | vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); |
| 8036 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 8037 | {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, |
| 8038 | // We only use split_k when group query attention is enabled, which means |
| 8039 | // there's no more than one tile of rows (i.e. workgroups_x would have been |
| 8040 | // one). We reuse workgroups_x to mean the number of splits, so we need to |
| 8041 | // cancel out the divide by wg_denoms[0]. |
| 8042 | pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); |
| 8043 | |
| 8044 | ggml_vk_sync_buffers(ctx, subctx); |
| 8045 | const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; |
| 8046 | ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, |
| 8047 | {split_k_buf, sinks_buf, dst_buf}, |
| 8048 | pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); |
| 8049 | ctx->prealloc_split_k_need_sync = true; |
| 8050 | } else { |
| 8051 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 8052 | {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, |
| 8053 | pc, { workgroups_x, workgroups_y, workgroups_z }); |
| 8054 | } |
| 8055 | } |
| 8056 | |
| 8057 | static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) { |
| 8058 | const ggml_tensor *src0 = dst->src[0]; |
| 8059 | const ggml_tensor *src1 = dst->src[1]; |
| 8060 | |
| 8061 | // src0 - kernel: [KW, KH, Cin, Cout] |
| 8062 | // src1 - input: [W, H, Cin, N] |
| 8063 | // dst - result: [OW, OH, Cout, N] |
| 8064 | |
| 8065 | // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) |
| 8066 | auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { |
| 8067 | return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; |
| 8068 | }; |
| 8069 | // parallelize in {OW/BS_K, OH/BS_NPQ, 1} |
| 8070 | int64_t W = src1->ne[0]; |
| 8071 | int64_t H = src1->ne[1]; |
| 8072 | int64_t KW = src0->ne[0]; |
| 8073 | int64_t KH = src0->ne[1]; |
| 8074 | int64_t Cout = src0->ne[3]; |
| 8075 | int64_t N = src1->ne[3]; |
| 8076 | int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); |
| 8077 | int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); |
| 8078 | int64_t NPQ = N * OW * OH; |
| 8079 | |
| 8080 | // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups |
| 8081 | std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 }; |
| 8082 | return elements; |
| 8083 | } |
| 8084 | |
| 8085 | static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) { |
| 8086 | const ggml_tensor *src0 = dst->src[0]; |
| 8087 | const ggml_tensor *src1 = dst->src[1]; |
| 8088 | |
| 8089 | // src0 - kernel: [KW, KH, Cout, Cin] |
| 8090 | // src1 - input: [W, H, Cin, N] |
| 8091 | // dst - result: [OW, OH, Cout, N] |
| 8092 | |
| 8093 | auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { |
| 8094 | return (ins - 1) * s - 2 * p + (ks - 1) * d + 1; |
| 8095 | }; |
| 8096 | // parallelize in {OW/BS_K, OH/BS_NPQ, 1} |
| 8097 | int64_t W = src1->ne[0]; |
| 8098 | int64_t H = src1->ne[1]; |
| 8099 | int64_t KW = src0->ne[0]; |
| 8100 | int64_t KH = src0->ne[1]; |
| 8101 | int64_t Cout = src0->ne[2]; |
| 8102 | int64_t N = src1->ne[3]; |
| 8103 | int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1); |
| 8104 | int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1); |
| 8105 | int64_t NPQ = N * OW * OH; |
| 8106 | |
| 8107 | // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups |
| 8108 | std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 }; |
| 8109 | return elements; |
| 8110 | } |
| 8111 | |
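| | // Select the pipeline for a given op and its operand/result types; returns nullptr when the combination is unsupported. |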
| 8112 | static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { |
| 8113 | switch (op) { |
| 8114 | case GGML_OP_GET_ROWS: |
| 8115 | GGML_ASSERT(src1->type == GGML_TYPE_I32); |
| 8116 | if (dst->type == GGML_TYPE_F16) { |
| 8117 | return ctx->device->pipeline_get_rows[src0->type]; |
| 8118 | } |
| 8119 | if (dst->type == GGML_TYPE_F32) { |
| 8120 | return ctx->device->pipeline_get_rows_f32[src0->type]; |
| 8121 | } |
| 8122 | return nullptr; |
| 8123 | case GGML_OP_ACC: |
| 8124 | if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8125 | return ctx->device->pipeline_acc_f32; |
| 8126 | } |
| 8127 | return nullptr; |
| 8128 | case GGML_OP_ADD: |
| 8129 | case GGML_OP_SUB: |
| 8130 | case GGML_OP_MUL: |
| 8131 | case GGML_OP_DIV: |
| 8132 | if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || |
| 8133 | (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || |
| 8134 | (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { |
| 8135 | return nullptr; |
| 8136 | } |
| 8137 | switch (op) { |
| 8138 | case GGML_OP_ADD: |
| 8139 | { |
| 8140 | if (ctx->num_additional_fused_ops > 0) { |
| 8141 | if (ctx->do_add_rms_partials) { |
| 8142 | return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops]; |
| 8143 | } else { |
| 8144 | return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; |
| 8145 | } |
| 8146 | } |
| 8147 | if (ctx->do_add_rms_partials) { |
| 8148 | auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms; |
| 8149 | return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; |
| 8150 | } else { |
| 8151 | auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; |
| 8152 | return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; |
| 8153 | } |
| 8154 | } |
| 8155 | case GGML_OP_SUB: |
| 8156 | { |
| 8157 | auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; |
| 8158 | return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; |
| 8159 | } |
| 8160 | case GGML_OP_MUL: |
| 8161 | { |
| 8162 | auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; |
| 8163 | return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; |
| 8164 | } |
| 8165 | case GGML_OP_DIV: |
| 8166 | { |
| 8167 | auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; |
| 8168 | return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; |
| 8169 | } |
| 8170 | default: |
| 8171 | break; |
| 8172 | } |
| 8173 | return nullptr; |
| 8174 | case GGML_OP_ADD_ID: |
| 8175 | if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { |
| 8176 | return ctx->device->pipeline_add_id_f32; |
| 8177 | } |
| 8178 | return nullptr; |
| 8179 | case GGML_OP_CONCAT: |
| 8180 | if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8181 | return ctx->device->pipeline_concat_f32; |
| 8182 | } |
| 8183 | if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |
| 8184 | return ctx->device->pipeline_concat_f16; |
| 8185 | } |
| 8186 | if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { |
| 8187 | return ctx->device->pipeline_concat_i32; |
| 8188 | } |
| 8189 | return nullptr; |
| 8190 | case GGML_OP_UPSCALE: |
| 8191 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8192 | ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF); |
| 8193 | switch (mode) { |
| 8194 | case GGML_SCALE_MODE_NEAREST: |
| 8195 | return ctx->device->pipeline_upscale_nearest_f32; |
| 8196 | case GGML_SCALE_MODE_BILINEAR: |
| 8197 | return ctx->device->pipeline_upscale_bilinear_f32; |
| 8198 | default: |
| 8199 | return nullptr; |
| 8200 | } |
| 8201 | } |
| 8202 | return nullptr; |
| 8203 | case GGML_OP_SCALE: |
| 8204 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8205 | return ctx->device->pipeline_scale_f32; |
| 8206 | } |
| 8207 | return nullptr; |
| 8208 | case GGML_OP_SQR: |
| 8209 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8210 | return ctx->device->pipeline_sqr_f32; |
| 8211 | } |
| 8212 | return nullptr; |
| 8213 | case GGML_OP_SQRT: |
| 8214 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8215 | return ctx->device->pipeline_sqrt_f32; |
| 8216 | } |
| 8217 | return nullptr; |
| 8218 | case GGML_OP_SIN: |
| 8219 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8220 | return ctx->device->pipeline_sin_f32; |
| 8221 | } |
| 8222 | return nullptr; |
| 8223 | case GGML_OP_COS: |
| 8224 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8225 | return ctx->device->pipeline_cos_f32; |
| 8226 | } |
| 8227 | return nullptr; |
| 8228 | case GGML_OP_CLAMP: |
| 8229 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8230 | return ctx->device->pipeline_clamp_f32; |
| 8231 | } |
| 8232 | return nullptr; |
| 8233 | case GGML_OP_PAD: |
| 8234 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8235 | return ctx->device->pipeline_pad_f32; |
| 8236 | } |
| 8237 | return nullptr; |
| 8238 | case GGML_OP_ROLL: |
| 8239 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8240 | return ctx->device->pipeline_roll_f32; |
| 8241 | } |
| 8242 | return nullptr; |
| 8243 | case GGML_OP_REPEAT: |
| 8244 | if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { |
| 8245 | return ctx->device->pipeline_repeat_f32; |
| 8246 | } |
| 8247 | return nullptr; |
| 8248 | case GGML_OP_REPEAT_BACK: |
| 8249 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8250 | return ctx->device->pipeline_repeat_back_f32; |
| 8251 | } |
| 8252 | return nullptr; |
| 8253 | case GGML_OP_CPY: |
| 8254 | case GGML_OP_CONT: |
| 8255 | case GGML_OP_DUP: |
| 8256 | return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); |
| 8257 | case GGML_OP_SET_ROWS: |
| 8258 | if (src1->type == GGML_TYPE_I64) { |
| 8259 | return ctx->device->pipeline_set_rows_i64[dst->type]; |
| 8260 | } else { |
| 8261 | return ctx->device->pipeline_set_rows_i32[dst->type]; |
| 8262 | } |
| 8263 | case GGML_OP_SILU_BACK: |
| 8264 | if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8265 | return ctx->device->pipeline_silu_back_f32; |
| 8266 | } |
| 8267 | return nullptr; |
| 8268 | case GGML_OP_NORM: |
| 8269 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8270 | return ctx->device->pipeline_norm_f32; |
| 8271 | } |
| 8272 | return nullptr; |
| 8273 | case GGML_OP_GROUP_NORM: |
| 8274 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8275 | return ctx->device->pipeline_group_norm_f32; |
| 8276 | } |
| 8277 | return nullptr; |
| 8278 | case GGML_OP_RMS_NORM: |
| 8279 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8280 | if (ctx->do_add_rms_partials) { |
| 8281 | return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32; |
| 8282 | } else { |
| 8283 | return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; |
| 8284 | } |
| 8285 | } |
| 8286 | return nullptr; |
| 8287 | case GGML_OP_RMS_NORM_BACK: |
| 8288 | if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8289 | return ctx->device->pipeline_rms_norm_back_f32; |
| 8290 | } |
| 8291 | return nullptr; |
| 8292 | case GGML_OP_L2_NORM: |
| 8293 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8294 | return ctx->device->pipeline_l2_norm_f32; |
| 8295 | } |
| 8296 | return nullptr; |
| 8297 | case GGML_OP_UNARY: |
| 8298 | if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || |
| 8299 | (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || |
| 8300 | (src0->type != dst->type)) { |
| 8301 | return nullptr; |
| 8302 | } |
| 8303 | |
| 8304 | switch (ggml_get_unary_op(dst)) { |
| 8305 | case GGML_UNARY_OP_EXP: |
| 8306 | return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; |
| 8307 | case GGML_UNARY_OP_SILU: |
| 8308 | return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; |
| 8309 | case GGML_UNARY_OP_GELU: |
| 8310 | return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; |
| 8311 | case GGML_UNARY_OP_GELU_ERF: |
| 8312 | return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16]; |
| 8313 | case GGML_UNARY_OP_GELU_QUICK: |
| 8314 | return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; |
| 8315 | case GGML_UNARY_OP_RELU: |
| 8316 | return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; |
| 8317 | case GGML_UNARY_OP_TANH: |
| 8318 | return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; |
| 8319 | case GGML_UNARY_OP_SIGMOID: |
| 8320 | return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; |
| 8321 | case GGML_UNARY_OP_HARDSIGMOID: |
| 8322 | return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16]; |
| 8323 | case GGML_UNARY_OP_HARDSWISH: |
| 8324 | return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; |
| 8325 | default: |
| 8326 | break; |
| 8327 | } |
| 8328 | return nullptr; |
| 8329 | case GGML_OP_GLU: |
| 8330 | if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || |
| 8331 | (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || |
| 8332 | (src0->type != dst->type)) { |
| 8333 | return nullptr; |
| 8334 | } |
| 8335 | |
| 8336 | switch (ggml_get_glu_op(dst)) { |
| 8337 | case GGML_GLU_OP_GEGLU: |
| 8338 | return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16]; |
| 8339 | case GGML_GLU_OP_REGLU: |
| 8340 | return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; |
| 8341 | case GGML_GLU_OP_SWIGLU: |
| 8342 | return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; |
| 8343 | case GGML_GLU_OP_SWIGLU_OAI: |
| 8344 | return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16]; |
| 8345 | case GGML_GLU_OP_GEGLU_ERF: |
| 8346 | return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; |
| 8347 | case GGML_GLU_OP_GEGLU_QUICK: |
| 8348 | return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16]; |
| 8349 | default: |
| 8350 | break; |
| 8351 | } |
| 8352 | return nullptr; |
| 8353 | case GGML_OP_DIAG_MASK_INF: |
| 8354 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8355 | return ctx->device->pipeline_diag_mask_inf_f32; |
| 8356 | } |
| 8357 | return nullptr; |
| 8358 | case GGML_OP_SOFT_MAX: |
| 8359 | GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); |
| 8360 | GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); |
| 8361 | |
| 8362 | if (ctx->num_additional_fused_ops) { |
| 8363 | uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); |
| 8364 | GGML_ASSERT(idx < num_topk_moe_pipelines); |
| 8365 | topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); |
| 8366 | return ctx->device->pipeline_topk_moe[idx][mode]; |
| 8367 | } |
| 8368 | |
| 8369 | if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { |
| 8370 | return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; |
| 8371 | } |
| 8372 | if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { |
| 8373 | return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; |
| 8374 | } |
| 8375 | return nullptr; |
| 8376 | case GGML_OP_SOFT_MAX_BACK: |
| 8377 | if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8378 | return ctx->device->pipeline_soft_max_back_f32; |
| 8379 | } |
| 8380 | return nullptr; |
| 8381 | case GGML_OP_ROPE: |
| 8382 | case GGML_OP_ROPE_BACK: |
| 8383 | { |
| 8384 | const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst; |
| 8385 | const int mode = ((const int32_t *) rope->op_params)[2]; |
| 8386 | const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; |
| 8387 | const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; |
| 8388 | const bool is_vision = mode == GGML_ROPE_TYPE_VISION; |
| 8389 | |
| 8390 | if (is_neox) { |
| 8391 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8392 | return ctx->device->pipeline_rope_neox_f32; |
| 8393 | } |
| 8394 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { |
| 8395 | return ctx->device->pipeline_rope_neox_f32_f16; |
| 8396 | } |
| 8397 | if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |
| 8398 | return ctx->device->pipeline_rope_neox_f16; |
| 8399 | } |
| 8400 | } else if (is_mrope && !is_vision) { |
| 8401 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8402 | return ctx->device->pipeline_rope_multi_f32; |
| 8403 | } |
| 8404 | if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |
| 8405 | return ctx->device->pipeline_rope_multi_f16; |
| 8406 | } |
| 8407 | } else if (is_vision) { |
| 8408 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8409 | return ctx->device->pipeline_rope_vision_f32; |
| 8410 | } |
| 8411 | if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |
| 8412 | return ctx->device->pipeline_rope_vision_f16; |
| 8413 | } |
| 8414 | } else { |
| 8415 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8416 | return ctx->device->pipeline_rope_norm_f32; |
| 8417 | } |
| 8418 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { |
| 8419 | return ctx->device->pipeline_rope_norm_f32_f16; |
| 8420 | } |
| 8421 | if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |
| 8422 | return ctx->device->pipeline_rope_norm_f16; |
| 8423 | } |
| 8424 | } |
| 8425 | return nullptr; |
| 8426 | } |
| 8427 | case GGML_OP_ARGSORT: |
| 8428 | if (ctx->num_additional_fused_ops) { |
| 8429 | uint32_t idx = (uint32_t)ceilf(x: log2f(x: float(dst->ne[0]))); |
| 8430 | GGML_ASSERT(idx < num_topk_moe_pipelines); |
| 8431 | topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(num: ctx->num_additional_fused_ops); |
| 8432 | return ctx->device->pipeline_topk_moe[idx][mode]; |
| 8433 | } |
| 8434 | |
| 8435 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { |
| 8436 | uint32_t idx = (uint32_t)ceilf(x: log2f(x: float(dst->ne[0]))); |
| 8437 | return ctx->device->pipeline_argsort_f32[idx]; |
| 8438 | } |
| 8439 | return nullptr; |
| 8440 | case GGML_OP_SUM: |
| 8441 | case GGML_OP_SUM_ROWS: |
| 8442 | case GGML_OP_MEAN: |
| 8443 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8444 | return ctx->device->pipeline_sum_rows_f32; |
| 8445 | } |
| 8446 | return nullptr; |
| 8447 | case GGML_OP_ARGMAX: |
| 8448 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { |
| 8449 | return ctx->device->pipeline_argmax_f32; |
| 8450 | } |
| 8451 | return nullptr; |
| 8452 | case GGML_OP_COUNT_EQUAL: |
| 8453 | if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) { |
| 8454 | return ctx->device->pipeline_count_equal_i32; |
| 8455 | } |
| 8456 | return nullptr; |
| 8457 | case GGML_OP_IM2COL: |
| 8458 | if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8459 | return ctx->device->pipeline_im2col_f32; |
| 8460 | } |
| 8461 | if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { |
| 8462 | return ctx->device->pipeline_im2col_f32_f16; |
| 8463 | } |
| 8464 | return nullptr; |
| 8465 | case GGML_OP_IM2COL_3D: |
| 8466 | if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8467 | return ctx->device->pipeline_im2col_3d_f32; |
| 8468 | } |
| 8469 | if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { |
| 8470 | return ctx->device->pipeline_im2col_3d_f32_f16; |
| 8471 | } |
| 8472 | return nullptr; |
| 8473 | case GGML_OP_TIMESTEP_EMBEDDING: |
| 8474 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8475 | return ctx->device->pipeline_timestep_embedding_f32; |
| 8476 | } |
| 8477 | return nullptr; |
| 8478 | case GGML_OP_CONV_TRANSPOSE_1D: |
| 8479 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8480 | return ctx->device->pipeline_conv_transpose_1d_f32; |
| 8481 | } |
| 8482 | return nullptr; |
| 8483 | case GGML_OP_POOL_2D: |
| 8484 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8485 | return ctx->device->pipeline_pool2d_f32; |
| 8486 | } |
| 8487 | return nullptr; |
| 8488 | case GGML_OP_RWKV_WKV6: |
| 8489 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8490 | return ctx->device->pipeline_rwkv_wkv6_f32; |
| 8491 | } |
| 8492 | return nullptr; |
| 8493 | case GGML_OP_RWKV_WKV7: |
| 8494 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8495 | return ctx->device->pipeline_rwkv_wkv7_f32; |
| 8496 | } |
| 8497 | return nullptr; |
| 8498 | case GGML_OP_SSM_SCAN: |
| 8499 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8500 | const uint32_t d_state = src0->ne[0]; |
| 8501 | if (d_state == 128) { |
| 8502 | return ctx->device->pipeline_ssm_scan_f32_d128; |
| 8503 | } else if (d_state == 256) { |
| 8504 | return ctx->device->pipeline_ssm_scan_f32_d256; |
| 8505 | } |
| 8506 | } |
| 8507 | return nullptr; |
| 8508 | case GGML_OP_SSM_CONV: |
| 8509 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8510 | return ctx->device->pipeline_ssm_conv_f32; |
| 8511 | } |
| 8512 | return nullptr; |
| 8513 | case GGML_OP_OPT_STEP_ADAMW: |
| 8514 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8515 | return ctx->device->pipeline_opt_step_adamw_f32; |
| 8516 | } |
| 8517 | return nullptr; |
| 8518 | case GGML_OP_OPT_STEP_SGD: |
| 8519 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8520 | return ctx->device->pipeline_opt_step_sgd_f32; |
| 8521 | } |
| 8522 | return nullptr; |
| 8523 | case GGML_OP_LEAKY_RELU: |
| 8524 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8525 | return ctx->device->pipeline_leaky_relu_f32; |
| 8526 | } |
| 8527 | return nullptr; |
| 8528 | case GGML_OP_CONV_2D: |
| 8529 | case GGML_OP_CONV_TRANSPOSE_2D: |
| 8530 | if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && |
| 8531 | ggml_is_contiguous(tensor: src0) && ggml_is_contiguous(tensor: src1) && ggml_is_contiguous(tensor: dst)) { |
| 8532 | std::array<uint32_t, 3> elements; |
| 8533 | if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); |
| 8534 | else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); |
| 8535 | vk_conv_shapes shape; |
| 8536 | |
| 8537 | uint32_t tiles[CONV_SHAPE_COUNT]; |
| 8538 | for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) { |
| 8539 | tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]); |
| 8540 | } |
| 8541 | |
| 8542 | // We can't query number of shader cores on Intel, use 32 as a placeholder |
| 8543 | // so small convolutions will still choose a smaller tile. |
| 8544 | const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; |
| 8545 | |
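|      | // Heuristic: tiles[i] is how many output tiles candidate shape i would produce. Prefer the
|      | // large 128x128 tile (or the narrow 32x256 tile for small output widths) only when it still
|      | // yields at least two tiles per shader core, so the GPU stays saturated; otherwise fall back
|      | // to the small 64x32 tile.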
| 8546 | if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) { |
| 8547 | shape = CONV_SHAPE_128x128; |
| 8548 | } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) { |
| 8549 | shape = CONV_SHAPE_32x256; |
| 8550 | } else { |
| 8551 | shape = CONV_SHAPE_64x32; |
| 8552 | } |
| 8553 | |
| 8554 | if (op == GGML_OP_CONV_2D) { |
| 8555 | if (src0->type == GGML_TYPE_F32) { |
| 8556 | return ctx->device->pipeline_conv2d_f32[shape]; |
| 8557 | } else if (src0->type == GGML_TYPE_F16) { |
| 8558 | return ctx->device->pipeline_conv2d_f16_f32[shape]; |
| 8559 | } |
| 8560 | } else if (op == GGML_OP_CONV_TRANSPOSE_2D) { |
| 8561 | if (src0->type == GGML_TYPE_F32) { |
| 8562 | return ctx->device->pipeline_conv_transpose_2d_f32[shape]; |
| 8563 | } else if (src0->type == GGML_TYPE_F16) { |
| 8564 | return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape]; |
| 8565 | } |
| 8566 | } |
| 8567 | } |
| 8568 | return nullptr; |
| 8569 | case GGML_OP_CONV_2D_DW: |
| 8570 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |
| 8571 | if (ggml_is_contiguous(tensor: src1)) { |
| 8572 | return ctx->device->pipeline_conv2d_dw_whcn_f32; |
| 8573 | } else if (ggml_is_contiguous_channels(tensor: src1)) { |
| 8574 | return ctx->device->pipeline_conv2d_dw_cwhn_f32; |
| 8575 | } |
| 8576 | } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { |
| 8577 | if (ggml_is_contiguous(tensor: src1)) { |
| 8578 | return ctx->device->pipeline_conv2d_dw_whcn_f16_f32; |
| 8579 | } else if (ggml_is_contiguous_channels(tensor: src1)) { |
| 8580 | return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32; |
| 8581 | } |
| 8582 | } |
| 8583 | return nullptr; |
| 8584 | default: |
| 8585 | return nullptr; |
| 8586 | } |
| 8587 | |
| 8588 | GGML_UNUSED(src2); |
| 8589 | } |
| 8590 | |
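|      | // Ops listed here receive full per-dimension strides in their push constants and can handle
|      | // non-contiguous tensors; all other ops assert dim01-contiguity in ggml_vk_op_f32 below.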
| 8591 | static bool ggml_vk_op_supports_incontiguous(ggml_op op) { |
| 8592 | switch (op) { |
| 8593 | case GGML_OP_CPY: |
| 8594 | case GGML_OP_GET_ROWS: |
| 8595 | case GGML_OP_ADD: |
| 8596 | case GGML_OP_SUB: |
| 8597 | case GGML_OP_MUL: |
| 8598 | case GGML_OP_DIV: |
| 8599 | case GGML_OP_ADD_ID: |
| 8600 | case GGML_OP_CONCAT: |
| 8601 | case GGML_OP_UPSCALE: |
| 8602 | case GGML_OP_SQR: |
| 8603 | case GGML_OP_SQRT: |
| 8604 | case GGML_OP_SIN: |
| 8605 | case GGML_OP_COS: |
| 8606 | case GGML_OP_CLAMP: |
| 8607 | case GGML_OP_PAD: |
| 8608 | case GGML_OP_REPEAT: |
| 8609 | case GGML_OP_REPEAT_BACK: |
| 8610 | case GGML_OP_ROPE: |
| 8611 | case GGML_OP_RMS_NORM: |
| 8612 | case GGML_OP_CONV_2D_DW: |
| 8613 | case GGML_OP_IM2COL: |
| 8614 | case GGML_OP_IM2COL_3D: |
| 8615 | case GGML_OP_SET_ROWS: |
| 8616 | case GGML_OP_SUM: |
| 8617 | case GGML_OP_SUM_ROWS: |
| 8618 | case GGML_OP_MEAN: |
| 8619 | return true; |
| 8620 | default: |
| 8621 | return false; |
| 8622 | } |
| 8623 | } |
| 8624 | |
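|      | // Storage buffer descriptors must be bound at offsets aligned to minStorageBufferOffsetAlignment.
|      | // The helpers below compute the sub-alignment remainder of a tensor's offset so the descriptor can
|      | // be bound at an aligned base while the remaining (element) offset travels in the push constants.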
| 8625 | static uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t) |
| 8626 | { |
| 8627 | return ((vk_tensor_offset(tensor: t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));
| 8628 | } |
| 8629 | |
| 8630 | template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { |
| 8631 | GGML_UNUSED(p); |
| 8632 | GGML_UNUSED(src0); |
| 8633 | GGML_UNUSED(src1); |
| 8634 | GGML_UNUSED(src2); |
| 8635 | GGML_UNUSED(src3); |
| 8636 | GGML_UNUSED(dst); |
| 8637 | static_assert(!std::is_const<T>::value, "unexpected type" ); |
| 8638 | GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); |
| 8639 | GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); |
| 8640 | GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); |
| 8641 | GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0); |
| 8642 | GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); |
| 8643 | } |
| 8644 | |
| 8645 | template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { |
| 8646 | const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type); |
| 8647 | const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type); |
| 8648 | |
| 8649 | p.misalign_offsets = (a_offset << 16) | d_offset; |
| 8650 | |
| 8651 | GGML_UNUSED(src1); |
| 8652 | GGML_UNUSED(src2); |
| 8653 | GGML_UNUSED(src3); |
| 8654 | } |
| 8655 | |
| 8656 | template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { |
| 8657 | const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type); |
| 8658 | const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type); |
| 8659 | |
| 8660 | p.misalign_offsets = (a_offset << 16) | d_offset; |
| 8661 | |
| 8662 | GGML_UNUSED(src1); |
| 8663 | GGML_UNUSED(src2); |
| 8664 | GGML_UNUSED(src3); |
| 8665 | } |
| 8666 | |
| 8667 | template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { |
| 8668 | const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type); |
| 8669 | const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type); |
| 8670 | |
| 8671 | p.misalign_offsets = (a_offset << 16) | d_offset; |
| 8672 | |
| 8673 | GGML_UNUSED(src1); |
| 8674 | GGML_UNUSED(src2); |
| 8675 | GGML_UNUSED(src3); |
| 8676 | } |
| 8677 | |
| 8678 | template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { |
| 8679 | const uint32_t a_offset = get_misalign_bytes(ctx, t: src1) / ggml_type_size(type: src1->type); |
| 8680 | const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type); |
| 8681 | |
| 8682 | p.misalign_offsets = (a_offset << 16) | d_offset; |
| 8683 | |
| 8684 | GGML_UNUSED(src0); |
| 8685 | GGML_UNUSED(src2); |
| 8686 | GGML_UNUSED(src3); |
| 8687 | } |
| 8688 | |
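|      | // Binary ops pack three offsets into one 32-bit value: a in bits 16..31, b in bits 8..15 and
|      | // d in bits 0..7, so b and d must each fit in 8 bits. The shader side presumably unpacks them as
|      | // a = v >> 16, b = (v >> 8) & 0xFF, d = v & 0xFF (shader detail not shown here).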
| 8689 | template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { |
| 8690 | const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type); |
| 8691 | const uint32_t b_offset = get_misalign_bytes(ctx, t: src1) / ggml_type_size(type: src1->type); |
| 8692 | const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type); |
| 8693 | |
| 8694 | GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); |
| 8695 | |
| 8696 | p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; |
| 8697 | |
| 8698 | GGML_UNUSED(src2); |
| 8699 | GGML_UNUSED(src3); |
| 8700 | } |
| 8701 | |
| 8702 | template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { |
| 8703 | const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type); |
| 8704 | const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type); |
| 8705 | |
| 8706 | p.a_offset = a_offset; |
| 8707 | p.d_offset = d_offset; |
| 8708 | |
| 8709 | GGML_UNUSED(src1); |
| 8710 | GGML_UNUSED(src2); |
| 8711 | GGML_UNUSED(src3); |
| 8712 | } |
| 8713 | |
| 8714 | template<typename PC> |
| 8715 | static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) { |
| 8716 | VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; |
| 8717 | if (src1 != nullptr) { |
| 8718 | std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; |
| 8719 | } |
| 8720 | if (src2 != nullptr) { |
| 8721 | std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; |
| 8722 | } |
| 8723 | if (src3 != nullptr) { |
| 8724 | std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3]; |
| 8725 | } |
| 8726 | std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; |
| 8727 | std::cerr << "), " << ggml_op_name(op) << ")" ); |
| 8728 | GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT |
| 8729 | GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT |
| 8730 | GGML_ASSERT(dst->buffer != nullptr); |
| 8731 | const uint64_t ne00 = src0->ne[0]; |
| 8732 | const uint64_t ne01 = src0->ne[1]; |
| 8733 | const uint64_t ne02 = src0->ne[2]; |
| 8734 | const uint64_t ne03 = src0->ne[3]; |
| 8735 | |
| 8736 | const bool use_src1 = src1 != nullptr; |
| 8737 | const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; |
| 8738 | const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; |
| 8739 | const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; |
| 8740 | const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; |
| 8741 | |
| 8742 | const bool use_src2 = src2 != nullptr; |
| 8743 | const bool use_src3 = src3 != nullptr; |
| 8744 | |
| 8745 | init_pushconst_fastdiv(pc); |
| 8746 | |
| 8747 | vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); |
| 8748 | |
| 8749 | if (pipeline == nullptr) { |
| 8750 | std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(type: src0->type); |
| 8751 | if (src1 != nullptr) { |
| 8752 | std::cerr << " and " << ggml_type_name(type: src1->type); |
| 8753 | } |
| 8754 | std::cerr << " to " << ggml_type_name(type: dst->type) << std::endl; |
| 8755 | GGML_ABORT("fatal error" ); |
| 8756 | } |
| 8757 | |
| 8758 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1); |
| 8759 | |
| 8760 | const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); |
| 8761 | |
| 8762 | vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, tensor: src0, allow_misalign: op_supports_incontiguous); |
| 8763 | vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, tensor: src1, allow_misalign: op_supports_incontiguous) : vk_subbuffer{}; |
| 8764 | vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, tensor: src2, allow_misalign: op_supports_incontiguous) : vk_subbuffer{}; |
| 8765 | vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, tensor: src3, allow_misalign: op_supports_incontiguous) : vk_subbuffer{}; |
| 8766 | vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, tensor: dst, allow_misalign: op_supports_incontiguous); |
| 8767 | |
| 8768 | // Compute misalignment offsets for the descriptors and store them in the push constants.
| 8769 | init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst); |
| 8770 | |
| 8771 | std::array<uint32_t, 3> elements; |
| 8772 | |
| 8773 | // Single call if dimension 2 is contiguous |
| 8774 | GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); |
| 8775 | |
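|      | // `elements` is the global work size per dimension; ggml_vk_dispatch_pipeline later divides it by
|      | // the pipeline's workgroup granularity (wg_denoms). Large 1D ranges are folded into a 512x512xZ grid
|      | // so each dimension stays within maxComputeWorkGroupCount.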
| 8776 | switch (op) { |
| 8777 | case GGML_OP_NORM: |
| 8778 | case GGML_OP_RMS_NORM_BACK: |
| 8779 | case GGML_OP_L2_NORM: |
| 8780 | case GGML_OP_SOFT_MAX: |
| 8781 | case GGML_OP_SOFT_MAX_BACK: |
| 8782 | case GGML_OP_SUM_ROWS: |
| 8783 | case GGML_OP_MEAN: |
| 8784 | case GGML_OP_ARGMAX: |
| 8785 | { |
| 8786 | const uint32_t nr = ggml_nrows(tensor: src0); |
| 8787 | if (nr > 262144) { |
| 8788 | elements = { 512, 512, CEIL_DIV(nr, 262144) }; |
| 8789 | } else if (nr > 512) { |
| 8790 | elements = { 512, CEIL_DIV(nr, 512), 1 }; |
| 8791 | } else { |
| 8792 | elements = { nr, 1, 1 }; |
| 8793 | } |
| 8794 | } break; |
| 8795 | case GGML_OP_RMS_NORM: |
| 8796 | if (ctx->do_add_rms_partials) { |
| 8797 | // Run one element per thread, 128 threads per workgroup |
| 8798 | elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 }; |
| 8799 | } else { |
| 8800 | elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; |
| 8801 | } |
| 8802 | break; |
| 8803 | |
| 8804 | case GGML_OP_SUM: |
| 8805 | // We use GGML_OP_SUM_ROWS with 1 row. |
| 8806 | elements = { 1, 1, 1 }; |
| 8807 | break; |
| 8808 | case GGML_OP_GROUP_NORM: |
| 8809 | { |
| 8810 | const uint32_t num_groups = dst->op_params[0]; |
| 8811 | elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; |
| 8812 | } break; |
| 8813 | case GGML_OP_DIAG_MASK_INF: |
| 8814 | case GGML_OP_ROPE: |
| 8815 | case GGML_OP_ROPE_BACK: |
| 8816 | elements = { (uint32_t)ggml_nrows(tensor: src0), (uint32_t)ne00, 1 }; |
| 8817 | break; |
| 8818 | case GGML_OP_GET_ROWS: |
| 8819 | elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; |
| 8820 | elements[1] = std::min(a: elements[1], b: ctx->device->properties.limits.maxComputeWorkGroupCount[1]); |
| 8821 | elements[2] = std::min(a: elements[2], b: ctx->device->properties.limits.maxComputeWorkGroupCount[2]); |
| 8822 | break; |
| 8823 | case GGML_OP_ARGSORT: |
| 8824 | elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(tensor: src0), 1 }; |
| 8825 | elements[1] = std::min(a: elements[1], b: ctx->device->properties.limits.maxComputeWorkGroupCount[1]); |
| 8826 | break; |
| 8827 | case GGML_OP_IM2COL: |
| 8828 | { |
| 8829 | const bool is_2D = dst->op_params[6] == 1; |
| 8830 | |
| 8831 | const uint32_t IC = src1->ne[is_2D ? 2 : 1]; |
| 8832 | |
| 8833 | const uint32_t KH = is_2D ? src0->ne[1] : 1; |
| 8834 | const uint32_t KW = src0->ne[0]; |
| 8835 | |
| 8836 | const uint32_t OH = is_2D ? dst->ne[2] : 1; |
| 8837 | const uint32_t OW = dst->ne[1]; |
| 8838 | |
| 8839 | const uint32_t batch = src1->ne[is_2D ? 3 : 2]; |
| 8840 | |
| 8841 | elements = { OW * KW * KH, OH, batch * IC }; |
| 8842 | } break; |
| 8843 | case GGML_OP_IM2COL_3D: |
| 8844 | { |
| 8845 | const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; |
| 8846 | |
| 8847 | const uint32_t N = ne13 / IC; |
| 8848 | |
| 8849 | const uint32_t KD = ne02; |
| 8850 | const uint32_t KH = ne01; |
| 8851 | const uint32_t KW = ne00; |
| 8852 | |
| 8853 | const uint32_t OD = dst->ne[3] / N; |
| 8854 | const uint32_t OH = dst->ne[2]; |
| 8855 | const uint32_t OW = dst->ne[1]; |
| 8856 | |
| 8857 | const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; |
| 8858 | const uint32_t N_OD_OH = N*OD*OH; |
| 8859 | |
| 8860 | elements = { IC_KD_KH_KW, OW, N_OD_OH }; |
| 8861 | elements[2] = std::min(a: elements[2], b: ctx->device->properties.limits.maxComputeWorkGroupCount[2]); |
| 8862 | } break; |
| 8863 | case GGML_OP_TIMESTEP_EMBEDDING: |
| 8864 | { |
| 8865 | const uint32_t dim = dst->op_params[0]; |
| 8866 | uint32_t half_ceil = (dim + 1) / 2; |
| 8867 | elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; |
| 8868 | } break; |
| 8869 | case GGML_OP_CONV_TRANSPOSE_1D: |
| 8870 | { |
| 8871 | elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1} |
| 8872 | } break; |
| 8873 | case GGML_OP_POOL_2D: |
| 8874 | { |
| 8875 | const uint32_t N = dst->ne[3]; |
| 8876 | const uint32_t OC = dst->ne[2]; |
| 8877 | const uint32_t OH = dst->ne[1]; |
| 8878 | const uint32_t OW = dst->ne[0]; |
| 8879 | elements = { N * OC * OH * OW, 1, 1}; |
| 8880 | } break; |
| 8881 | case GGML_OP_CONV_2D: |
| 8882 | { |
| 8883 | elements = ggml_vk_get_conv_elements(dst); |
| 8884 | } break; |
| 8885 | case GGML_OP_CONV_TRANSPOSE_2D: |
| 8886 | { |
| 8887 | elements = ggml_vk_get_conv_transpose_2d_elements(dst); |
| 8888 | } break; |
| 8889 | case GGML_OP_ADD: |
| 8890 | case GGML_OP_SUB: |
| 8891 | case GGML_OP_DIV: |
| 8892 | case GGML_OP_MUL: |
| 8893 | case GGML_OP_SCALE: |
| 8894 | case GGML_OP_SQR: |
| 8895 | case GGML_OP_SQRT: |
| 8896 | case GGML_OP_SIN: |
| 8897 | case GGML_OP_COS: |
| 8898 | case GGML_OP_CLAMP: |
| 8899 | case GGML_OP_PAD: |
| 8900 | case GGML_OP_ROLL: |
| 8901 | case GGML_OP_REPEAT: |
| 8902 | case GGML_OP_REPEAT_BACK: |
| 8903 | case GGML_OP_CPY: |
| 8904 | case GGML_OP_CONCAT: |
| 8905 | case GGML_OP_UPSCALE: |
| 8906 | case GGML_OP_UNARY: |
| 8907 | case GGML_OP_GLU: |
| 8908 | case GGML_OP_CONV_2D_DW: |
| 8909 | { |
| 8910 | uint32_t ne = ggml_nelements(tensor: dst); |
| 8911 | if (op == GGML_OP_CPY && ggml_is_quantized(type: src0->type) && ggml_is_quantized(type: dst->type)) { |
| 8912 | // Convert from number of logical elements to 2- or 4-byte units. |
| 8913 | ne /= ggml_blck_size(type: src0->type); |
| 8914 | if ((ggml_type_size(type: src0->type) % 4) == 0) { |
| 8915 | ne *= ggml_type_size(type: src0->type) / 4; |
| 8916 | } else { |
| 8917 | ne *= ggml_type_size(type: src0->type) / 2; |
| 8918 | } |
| 8919 | } |
| 8920 | // copy_to_quant has block size of 32, and each thread does QUANT_K elements. |
| 8921 | // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements. |
| 8922 | // So divide by block size here before splitting into 512x512 groups. |
| 8923 | if (op == GGML_OP_CPY && !ggml_is_quantized(type: src0->type) && ggml_is_quantized(type: dst->type)) { |
| 8924 | ne = CEIL_DIV(ne, ggml_blck_size(dst->type)); |
| 8925 | } |
| 8926 | if (ne > 262144) { |
| 8927 | elements = { 512, 512, CEIL_DIV(ne, 262144) }; |
| 8928 | } else if (ne > 512) { |
| 8929 | elements = { 512, CEIL_DIV(ne, 512), 1 }; |
| 8930 | } else { |
| 8931 | elements = { ne, 1, 1 }; |
| 8932 | } |
| 8933 | } break; |
| 8934 | case GGML_OP_ADD_ID: |
| 8935 | { |
| 8936 | elements = { (uint32_t)ne01, (uint32_t)ne02, 1 }; |
| 8937 | } break; |
| 8938 | case GGML_OP_SET_ROWS: |
| 8939 | { |
| 8940 | uint32_t ne = ggml_nelements(tensor: src0); |
| 8941 | if (ggml_is_quantized(type: dst->type)) { |
| 8942 | // quants run 32 threads each doing QUANT_K elements |
| 8943 | ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type)); |
| 8944 | } else { |
| 8945 | // scalar types do one element per thread, running 512 threads |
| 8946 | ne = CEIL_DIV(ne, 512); |
| 8947 | } |
| 8948 | if (ne > 262144) { |
| 8949 | elements = { 512, 512, CEIL_DIV(ne, 262144) }; |
| 8950 | } else if (ne > 512) { |
| 8951 | elements = { 512, CEIL_DIV(ne, 512), 1 }; |
| 8952 | } else { |
| 8953 | elements = { ne, 1, 1 }; |
| 8954 | } |
| 8955 | } |
| 8956 | break; |
| 8957 | case GGML_OP_SSM_CONV: |
| 8958 | { |
| 8959 | const uint32_t nr = src0->ne[1]; |
| 8960 | const uint32_t n_t = dst->ne[1]; |
| 8961 | const uint32_t n_s = dst->ne[2]; |
| 8962 | elements = { nr, n_t, n_s }; |
| 8963 | } |
| 8964 | break; |
| 8965 | default: |
| 8966 | elements = { (uint32_t)ggml_nelements(tensor: src0), 1, 1 }; |
| 8967 | break; |
| 8968 | } |
| 8969 | |
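|      | // Fused ADD + RMS_NORM: when do_add_rms_partials is set, the extra binding points at the
|      | // preallocated partials buffer; the ADD shader is expected to write per-workgroup partial sums
|      | // there so the following RMS_NORM dispatch can reuse them instead of re-reading the row.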
| 8970 | if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { |
| 8971 | vk_subbuffer a_buf = src0_buf; |
| 8972 | if (ctx->do_add_rms_partials) { |
| 8973 | a_buf = ggml_vk_subbuffer(ctx, buf: ctx->prealloc_add_rms_partials, offset: ctx->prealloc_size_add_rms_partials_offset); |
| 8974 | } |
| 8975 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 8976 | { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements); |
| 8977 | } else if (op == GGML_OP_GLU) { |
| 8978 | // Empty src1 is possible in glu, but the shader needs a buffer |
| 8979 | vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf; |
| 8980 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements); |
| 8981 | } else if (op == GGML_OP_SOFT_MAX) { |
| 8982 | // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer |
| 8983 | vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf; |
| 8984 | vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf; |
| 8985 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements); |
| 8986 | } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { |
| 8987 | // Empty src2 and src3 is possible in rope, but the shader needs a buffer |
| 8988 | vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf; |
| 8989 | vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf; |
| 8990 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements); |
| 8991 | } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { |
| 8992 | if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { |
| 8993 | // buffer device address path doesn't use dst buffer |
| 8994 | dst_buf.size = 1; |
| 8995 | } |
| 8996 | // im2col uses only src1 and dst buffers |
| 8997 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements); |
| 8998 | } else if (op == GGML_OP_COUNT_EQUAL) { |
| 8999 | // count_equal assumes that destination buffer is initialized with zeroes |
| 9000 | ggml_vk_buffer_memset_async(ctx&: subctx, dst&: dst_buf.buffer, offset: dst_buf.offset, c: 0, size: dst_buf.size); |
| 9001 | ggml_vk_sync_buffers(ctx, subctx); |
| 9002 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements); |
| 9003 | } else if (op == GGML_OP_OPT_STEP_SGD) { |
| 9004 | // OPT_STEP_SGD works on src0, it does not need dst |
| 9005 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements); |
| 9006 | } else if (use_src3) { |
| 9007 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements); |
| 9008 | } else if (use_src2) { |
| 9009 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements); |
| 9010 | } else if (use_src1) { |
| 9011 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements); |
| 9012 | } else { |
| 9013 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements); |
| 9014 | } |
| 9015 | } |
| 9016 | |
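|      | // Convention for the vk_op_binary_push_constants initializers below: ne is the total element
|      | // count driving the 1D dispatch, and every nb* stride is converted from bytes to elements by
|      | // dividing by the tensor's type size.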
| 9017 | static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9018 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9019 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9020 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9021 | |
| 9022 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_GET_ROWS, pc: { |
| 9023 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9024 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9025 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9026 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9027 | .misalign_offsets: 0, |
| 9028 | .param1: 0.0f, .param2: 0.0f, .param3: 0, |
| 9029 | }); |
| 9030 | } |
| 9031 | |
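|      | // GGML_OP_ACC adds src1 into a strided view of src0 (the rest of src0 is copied to dst).
|      | // op_params carry the view's nb1/nb2 strides and its start offset in bytes.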
| 9032 | static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9033 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9034 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9035 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9036 | |
| 9037 | int nb1 = dst->op_params[0] / 4; // view row stride, converted from bytes to float elements
| 9038 | int nb2 = dst->op_params[1] / 4; // view plane stride, converted from bytes to float elements
| 9039 | // int nb3 = dst->op_params[2] / 4; // unused
| 9040 | int offset = dst->op_params[3] / 4; // view start offset, converted from bytes to float elements
| 9041 | |
| 9042 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ACC, pc: { |
| 9043 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9044 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)nb1, .nb02: (uint32_t)nb2, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9045 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9046 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t)nb1, .nb22: (uint32_t)nb2, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9047 | .misalign_offsets: 0, |
| 9048 | .param1: 0.0f, .param2: 0.0f, .param3: offset, |
| 9049 | }); |
| 9050 | } |
| 9051 | |
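|      | // Fuses a chain of consecutive ADD nodes into a single dispatch. The operand list holds both
|      | // sources of the first node plus the non-chained source of each subsequent add, with the final
|      | // destination last, e.g. d = ((a + b) + c) + e uses tensors = { a, b, c, e, d }.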
| 9052 | static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { |
| 9053 | const ggml_tensor *first_node = cgraph->nodes[node_idx]; |
| 9054 | const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; |
| 9055 | |
| 9056 | // Make a list of all the tensors used by the op. |
| 9057 | // Last element of the list is the dest tensor. |
| 9058 | const ggml_tensor *tensors[MAX_PARAMETER_COUNT]; |
| 9059 | uint32_t num_srcs = ctx->num_additional_fused_ops + 2; |
| 9060 | uint32_t num_tensors = num_srcs + 1; |
| 9061 | GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT); |
| 9062 | |
| 9063 | tensors[0] = first_node->src[0]; |
| 9064 | tensors[1] = first_node->src[1]; |
| 9065 | for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) { |
| 9066 | // check whether the previous result is src[0] or src[1] |
| 9067 | if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) { |
| 9068 | tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1]; |
| 9069 | } else { |
| 9070 | tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0]; |
| 9071 | } |
| 9072 | } |
| 9073 | tensors[num_srcs] = dst; |
| 9074 | |
| 9075 | vk_op_multi_add_push_constants pc; |
| 9076 | pc.ne20 = (uint32_t)dst->ne[0]; |
| 9077 | pc.ne21 = (uint32_t)dst->ne[1]; |
| 9078 | pc.ne22 = (uint32_t)dst->ne[2]; |
| 9079 | pc.ne23 = (uint32_t)dst->ne[3]; |
| 9080 | |
| 9081 | for (uint32_t i = 0; i < num_tensors; ++i) { |
| 9082 | const ggml_tensor *t = tensors[i]; |
| 9083 | pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float); |
| 9084 | pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float); |
| 9085 | pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float); |
| 9086 | pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float); |
| 9087 | } |
| 9088 | pc.rms_partials = ctx->do_add_rms_partials; |
| 9089 | |
| 9090 | vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: tensors[0], src1: tensors[1], src2: nullptr, dst, op: dst->op); |
| 9091 | |
| 9092 | if (pipeline == nullptr) { |
| 9093 | std::cerr << "ggml_vulkan: Error: Missing multi_add" ; |
| 9094 | GGML_ABORT("fatal error" ); |
| 9095 | } |
| 9096 | |
| 9097 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1); |
| 9098 | |
| 9099 | ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT]; |
| 9100 | vk_buffer buf[MAX_PARAMETER_COUNT]; |
| 9101 | size_t offset[MAX_PARAMETER_COUNT]; |
| 9102 | bool uma[MAX_PARAMETER_COUNT]; |
| 9103 | |
| 9104 | for (uint32_t i = 0; i < num_tensors; ++i) { |
| 9105 | buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context; |
| 9106 | buf[i] = nullptr; |
| 9107 | offset[i] = 0; |
| 9108 | uma[i] = false; |
| 9109 | |
| 9110 | if (ctx->device->uma) { |
| 9111 | ggml_vk_host_get(device: ctx->device, ptr: tensors[i]->data, buf&: buf[i], buf_offset&: offset[i]); |
| 9112 | uma[i] = buf[i] != nullptr; |
| 9113 | } |
| 9114 | if (!uma[i]) { |
| 9115 | buf[i] = buf_ctx[i]->dev_buffer; |
| 9116 | offset[i] = vk_tensor_offset(tensor: tensors[i]) + tensors[i]->view_offs; |
| 9117 | } |
| 9118 | GGML_ASSERT(buf[i] != nullptr); |
| 9119 | } |
| 9120 | // If any remaining descriptors are unused, just point them at src[0] |
| 9121 | for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) { |
| 9122 | buf[i] = buf[0]; |
| 9123 | offset[i] = 0; |
| 9124 | } |
| 9125 | if (ctx->do_add_rms_partials) { |
| 9126 | buf[num_tensors] = ctx->prealloc_add_rms_partials; |
| 9127 | offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset; |
| 9128 | } |
| 9129 | |
| 9130 | std::array<uint32_t, 3> elements; |
| 9131 | |
| 9132 | uint32_t ne = ggml_nelements(tensor: dst); |
| 9133 | if (ne > 262144) { |
| 9134 | elements = { 512, 512, CEIL_DIV(ne, 262144) }; |
| 9135 | } else if (ne > 512) { |
| 9136 | elements = { 512, CEIL_DIV(ne, 512), 1 }; |
| 9137 | } else { |
| 9138 | elements = { ne, 1, 1 }; |
| 9139 | } |
| 9140 | |
| 9141 | static_assert(MAX_PARAMETER_COUNT == 12); |
| 9142 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 9143 | descriptor_buffer_infos: { |
| 9144 | ggml_vk_subbuffer(ctx, buf: buf[0], offset: offset[0]), |
| 9145 | ggml_vk_subbuffer(ctx, buf: buf[1], offset: offset[1]), |
| 9146 | ggml_vk_subbuffer(ctx, buf: buf[2], offset: offset[2]), |
| 9147 | ggml_vk_subbuffer(ctx, buf: buf[3], offset: offset[3]), |
| 9148 | ggml_vk_subbuffer(ctx, buf: buf[4], offset: offset[4]), |
| 9149 | ggml_vk_subbuffer(ctx, buf: buf[5], offset: offset[5]), |
| 9150 | ggml_vk_subbuffer(ctx, buf: buf[6], offset: offset[6]), |
| 9151 | ggml_vk_subbuffer(ctx, buf: buf[7], offset: offset[7]), |
| 9152 | ggml_vk_subbuffer(ctx, buf: buf[8], offset: offset[8]), |
| 9153 | ggml_vk_subbuffer(ctx, buf: buf[9], offset: offset[9]), |
| 9154 | ggml_vk_subbuffer(ctx, buf: buf[10], offset: offset[10]), |
| 9155 | ggml_vk_subbuffer(ctx, buf: buf[11], offset: offset[11]), |
| 9156 | }, push_constants: pc, elements); |
| 9157 | } |
| 9158 | |
| 9159 | static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9160 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9161 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9162 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9163 | |
| 9164 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ADD, pc: { |
| 9165 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9166 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9167 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9168 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9169 | .misalign_offsets: 0, |
| 9170 | .param1: 0.0f, .param2: 0.0f, .param3: ctx->do_add_rms_partials, |
| 9171 | }); |
| 9172 | } |
| 9173 | |
| 9174 | static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9175 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9176 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9177 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9178 | |
| 9179 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SUB, pc: { |
| 9180 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9181 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9182 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9183 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9184 | .misalign_offsets: 0, |
| 9185 | .param1: 0.0f, .param2: 0.0f, .param3: 0, |
| 9186 | }); |
| 9187 | } |
| 9188 | |
| 9189 | static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9190 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9191 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9192 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9193 | |
| 9194 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_MUL, pc: { |
| 9195 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9196 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9197 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9198 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9199 | .misalign_offsets: 0, |
| 9200 | .param1: 0.0f, .param2: 0.0f, .param3: 0, |
| 9201 | }); |
| 9202 | } |
| 9203 | |
| 9204 | static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9205 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9206 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9207 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9208 | |
| 9209 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_DIV, pc: { |
| 9210 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9211 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9212 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9213 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9214 | .misalign_offsets: 0, |
| 9215 | .param1: 0.0f, .param2: 0.0f, .param3: 0, |
| 9216 | }); |
| 9217 | } |
| 9218 | |
| 9219 | static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { |
| 9220 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9221 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9222 | const uint32_t src2_type_size = ggml_type_size(type: src2->type); |
| 9223 | |
| 9224 | ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, src3: nullptr, dst, op: GGML_OP_ADD_ID, pc: { |
| 9225 | .ne0: (uint32_t)dst->ne[0], |
| 9226 | .ne1: (uint32_t)dst->ne[1], |
| 9227 | .s01: (uint32_t)src0->nb[1] / src0_type_size, |
| 9228 | .s02: (uint32_t)src0->nb[2] / src0_type_size, |
| 9229 | .s11: (uint32_t)src1->nb[1] / src1_type_size, |
| 9230 | .s21: (uint32_t)src2->nb[1] / src2_type_size, |
| 9231 | }); |
| 9232 | } |
| 9233 | |
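|      | // Dispatches the RWKV WKV6/WKV7 kernels: the grid is { B * H, 1, 1 },
|      | // i.e. one slot per (sequence, head) pair.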
| 9234 | static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) { |
| 9235 | GGML_ASSERT(version == 6 || version == 7); |
| 9236 | int num_srcs = version == 6 ? 6 : 7; |
| 9237 | |
| 9238 | for (int i = 0; i < num_srcs; i++) { |
| 9239 | GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); |
| 9240 | } |
| 9241 | |
| 9242 | GGML_ASSERT(dst->buffer != nullptr); |
| 9243 | |
| 9244 | vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: dst->src[0], src1: dst->src[1], src2: dst->src[2], dst, op: dst->op); |
| 9245 | GGML_ASSERT(pipeline != nullptr); |
| 9246 | |
| 9247 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1); |
| 9248 | |
| 9249 | vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, tensor: dst); |
| 9250 | vk_subbuffer src_buf[7] = {}; |
| 9251 | for (int i = 0; i < num_srcs; i++) { |
| 9252 | src_buf[i] = ggml_vk_tensor_subbuffer(ctx, tensor: dst->src[i]); |
| 9253 | } |
| 9254 | |
| 9255 | std::array<uint32_t, 3> elements = { |
| 9256 | (uint32_t)(pc.B * pc.H), |
| 9257 | 1, |
| 9258 | 1 |
| 9259 | }; |
| 9260 | |
| 9261 | if (version == 6) { |
| 9262 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 9263 | descriptor_buffer_infos: {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, |
| 9264 | push_constants: pc, elements); |
| 9265 | } else if (version == 7) { |
| 9266 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 9267 | descriptor_buffer_infos: {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, |
| 9268 | push_constants: pc, elements); |
| 9269 | } else { |
| 9270 | // shouldn't happen |
| 9271 | GGML_ASSERT(false); |
| 9272 | } |
| 9273 | } |
| 9274 | |
| 9275 | static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { |
| 9276 | const size_t seq_length = dst->src[0]->ne[2]; |
| 9277 | const size_t n_embed = dst->ne[0]; |
| 9278 | const size_t n_heads = dst->src[0]->ne[1]; |
| 9279 | const size_t n_seqs = dst->src[5]->ne[1]; |
| 9280 | |
| 9281 | ggml_vk_op_f32_wkv( |
| 9282 | ctx, subctx, dst, |
| 9283 | pc: { |
| 9284 | .B: (uint32_t)n_seqs, |
| 9285 | .T: (uint32_t)seq_length, |
| 9286 | .C: (uint32_t)n_embed, |
| 9287 | .H: (uint32_t)n_heads, |
| 9288 | }, |
| 9289 | version: 6 |
| 9290 | ); |
| 9291 | } |
| 9292 | |
| 9293 | static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { |
| 9294 | const size_t seq_length = dst->src[0]->ne[2]; |
| 9295 | const size_t n_embed = dst->ne[0]; |
| 9296 | const size_t n_heads = dst->src[0]->ne[1]; |
| 9297 | const size_t n_seqs = dst->src[6]->ne[1]; |
| 9298 | |
| 9299 | ggml_vk_op_f32_wkv( |
| 9300 | ctx, subctx, dst, |
| 9301 | pc: { |
| 9302 | .B: (uint32_t)n_seqs, |
| 9303 | .T: (uint32_t)seq_length, |
| 9304 | .C: (uint32_t)n_embed, |
| 9305 | .H: (uint32_t)n_heads, |
| 9306 | }, |
| 9307 | version: 7 |
| 9308 | ); |
| 9309 | } |
| 9310 | |
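|      | // Mamba-2 style SSM scan: each workgroup handles splitH (16) of the n_head * head_dim state rows
|      | // for one sequence. s_off appears to be the byte offset of the updated state within dst, i.e.
|      | // just past the y output (which has the same element count as src1).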
| 9311 | static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { |
| 9312 | const ggml_tensor * src0 = dst->src[0]; |
| 9313 | const ggml_tensor * src1 = dst->src[1]; |
| 9314 | const ggml_tensor * src2 = dst->src[2]; |
| 9315 | const ggml_tensor * src3 = dst->src[3]; |
| 9316 | const ggml_tensor * src4 = dst->src[4]; |
| 9317 | const ggml_tensor * src5 = dst->src[5]; |
| 9318 | |
| 9319 | GGML_ASSERT(dst->buffer != nullptr); |
| 9320 | |
| 9321 | const uint32_t head_dim = src0->ne[1]; |
| 9322 | const uint32_t n_head = src1->ne[1]; |
| 9323 | const uint32_t n_group = src4->ne[1]; |
| 9324 | const uint32_t n_tok = src1->ne[2]; |
| 9325 | const uint32_t n_seq = src1->ne[3]; |
| 9326 | |
| 9327 | bool is_mamba2 = (src3->nb[1] == sizeof(float)); |
| 9328 | GGML_ASSERT(is_mamba2); |
| 9329 | |
| 9330 | vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op: dst->op); |
| 9331 | GGML_ASSERT(pipeline != nullptr); |
| 9332 | |
| 9333 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1); |
| 9334 | |
| 9335 | const int64_t s_off = ggml_nelements(tensor: src1) * sizeof(float); |
| 9336 | |
| 9337 | const vk_op_ssm_scan_push_constants pc = { |
| 9338 | .nb02: (uint32_t)src0->nb[2], .nb03: (uint32_t)src0->nb[3], |
| 9339 | .nb12: (uint32_t)src1->nb[2], .nb13: (uint32_t)src1->nb[3], |
| 9340 | .nb21: (uint32_t)src2->nb[1], .nb22: (uint32_t)src2->nb[2], |
| 9341 | .nb31: (uint32_t)src3->nb[1], |
| 9342 | .nb42: (uint32_t)src4->nb[2], .nb43: (uint32_t)src4->nb[3], |
| 9343 | .nb52: (uint32_t)src5->nb[2], .nb53: (uint32_t)src5->nb[3], |
| 9344 | .s_off: (uint32_t)s_off, |
| 9345 | .n_head: n_head, .d_head: head_dim, .n_group: n_group, .n_tok: n_tok |
| 9346 | }; |
| 9347 | |
| 9348 | vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, tensor: dst); |
| 9349 | vk_subbuffer src_buf[7] = {}; |
| 9350 | for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) { |
| 9351 | src_buf[i] = ggml_vk_tensor_subbuffer(ctx, tensor: dst->src[i]); |
| 9352 | } |
| 9353 | |
| 9354 | std::array<uint32_t, 3> elements; |
| 9355 | |
| 9356 | const int splitH = 16; |
| 9357 | const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH); |
| 9358 | const uint32_t num_workgroups_y = n_seq; |
| 9359 | elements = { num_workgroups_x, num_workgroups_y, 1 }; |
| 9360 | |
| 9361 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 9362 | descriptor_buffer_infos: {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, |
| 9363 | push_constants: pc, elements); |
| 9364 | } |
| 9365 | |
| 9366 | static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { |
| 9367 | const ggml_tensor * src0 = dst->src[0]; |
| 9368 | const ggml_tensor * src1 = dst->src[1]; |
| 9369 | |
| 9370 | ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SSM_CONV, pc: { |
| 9371 | .nb01: (uint32_t)src0->nb[1], .nb02: (uint32_t)src0->nb[2], |
| 9372 | .nb11: (uint32_t)src1->nb[1], |
| 9373 | .dst_nb0: (uint32_t)dst->nb[0], .dst_nb1: (uint32_t)dst->nb[1], .dst_nb2: (uint32_t)dst->nb[2], |
| 9374 | .nc: (uint32_t)src1->ne[0], |
| 9375 | .ncs: (uint32_t)src0->ne[0], |
| 9376 | .nr: (uint32_t)src0->ne[1], |
| 9377 | .n_t: (uint32_t)dst->ne[1], |
| 9378 | .n_s: (uint32_t)dst->ne[2], |
| 9379 | }); |
| 9380 | } |
| 9381 | |
| 9382 | static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc) { |
| 9383 | const ggml_tensor * x = dst->src[0]; |
| 9384 | const ggml_tensor * g = dst->src[1]; |
| 9385 | const ggml_tensor * gm = dst->src[2]; |
| 9386 | const ggml_tensor * gv = dst->src[3]; |
| 9387 | const ggml_tensor * p = dst->src[4]; |
| 9388 | |
| 9389 | GGML_ASSERT(x->type == GGML_TYPE_F32); |
| 9390 | GGML_ASSERT(g->type == GGML_TYPE_F32); |
| 9391 | GGML_ASSERT(gm->type == GGML_TYPE_F32); |
| 9392 | GGML_ASSERT(gv->type == GGML_TYPE_F32); |
| 9393 | GGML_ASSERT(p->type == GGML_TYPE_F32); |
| 9394 | GGML_ASSERT(dst->buffer != nullptr); |
| 9395 | GGML_ASSERT(ggml_is_contiguous(x)); |
| 9396 | GGML_ASSERT(ggml_is_contiguous(g)); |
| 9397 | GGML_ASSERT(ggml_is_contiguous(gm)); |
| 9398 | GGML_ASSERT(ggml_is_contiguous(gv)); |
| 9399 | GGML_ASSERT(ggml_is_contiguous(p)); |
| 9400 | GGML_ASSERT(ggml_are_same_shape(x, g)); |
| 9401 | GGML_ASSERT(ggml_are_same_shape(x, gm)); |
| 9402 | GGML_ASSERT(ggml_are_same_shape(x, gv)); |
| 9403 | GGML_ASSERT(ggml_nelements(p) == 7); |
| 9404 | |
| 9405 | vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: g, src1: gm, src2: gv, dst, op: GGML_OP_OPT_STEP_ADAMW); |
| 9406 | GGML_ASSERT(pipeline != nullptr); |
| 9407 | |
| 9408 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1); |
| 9409 | |
| 9410 | vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, tensor: x); |
| 9411 | vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, tensor: g); |
| 9412 | vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, tensor: gm); |
| 9413 | vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, tensor: gv); |
| 9414 | vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, tensor: p); |
| 9415 | |
| 9416 | std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(tensor: x), 1, 1 }; |
| 9417 | |
| 9418 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 9419 | descriptor_buffer_infos: {x_buf, g_buf, gm_buf, gv_buf, p_buf}, |
| 9420 | push_constants: pc, elements); |
| 9421 | } |
| 9422 | |
| 9423 | static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { |
| 9424 | const size_t n = ggml_nelements(tensor: dst->src[0]); |
| 9425 | |
| 9426 | ggml_vk_op_f32_opt_step_adamw( |
| 9427 | ctx, subctx, dst, |
| 9428 | pc: { .KX: (uint32_t)n, .KY: 0, .param1: 0.0f, .param2: 0.0f } |
| 9429 | ); |
| 9430 | } |
| 9431 | |
| 9432 | static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { |
| 9433 | const size_t n = ggml_nelements(tensor: dst->src[0]); |
| 9434 | |
| 9435 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, src3: nullptr, dst, op: GGML_OP_OPT_STEP_SGD, pc: { .KX: (uint32_t)n, .KY: 0, .param1: 0.0f, .param2: 0.0f }); |
| 9436 | } |
| 9437 | |
| 9438 | static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9439 | int * op_params = (int *)dst->op_params; |
| 9440 | |
| 9441 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9442 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9443 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9444 | |
| 9445 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONCAT, pc: { |
| 9446 | .ne: (uint32_t)ggml_nelements(tensor: dst), |
| 9447 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9448 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9449 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9450 | .misalign_offsets: 0, |
| 9451 | .param1: 0.0f, .param2: 0.0f, .param3: op_params[0], |
| 9452 | }); |
| 9453 | } |
| 9454 | |
| 9455 | static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9456 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9457 | const uint32_t mode = (uint32_t)ggml_get_op_params_i32(tensor: dst, i: 0); |
| 9458 | |
| 9459 | GGML_TENSOR_UNARY_OP_LOCALS |
| 9460 | |
| 9461 | float sf0 = (float)ne0 / ne00; |
| 9462 | float sf1 = (float)ne1 / ne01; |
| 9463 | float sf2 = (float)ne2 / ne02; |
| 9464 | float sf3 = (float)ne3 / ne03; |
| 9465 | float pixel_offset = 0.5f; |
| 9466 | |
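|  | // GGML_SCALE_FLAG_ALIGN_CORNERS maps the corner samples of input and output onto
|  | // each other: scale by (n_out - 1) / (n_in - 1) and drop the half-pixel offset.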
| 9467 | if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { |
| 9468 | sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; |
| 9469 | sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; |
| 9470 | pixel_offset = 0.0f; |
| 9471 | } |
| 9472 | |
| 9473 | ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_UPSCALE, pc: { |
| 9474 | .ne: (uint32_t)ggml_nelements(tensor: dst), .a_offset: 0, .d_offset: 0, |
| 9475 | .ne00: (uint32_t)ne00, .ne01: (uint32_t)ne01, |
| 9476 | .nb00: (uint32_t)nb00 / src0_type_size, .nb01: (uint32_t)nb01 / src0_type_size, .nb02: (uint32_t)nb02 / src0_type_size, .nb03: (uint32_t)nb03 / src0_type_size, |
| 9477 | .ne10: (uint32_t)ne0, .ne11: (uint32_t)ne1, .ne12: (uint32_t)ne2, .ne13: (uint32_t)ne3, |
| 9478 | .sf0: sf0, .sf1: sf1, .sf2: sf2, .sf3: sf3, .pixel_offset: pixel_offset |
| 9479 | }); |
| 9480 | } |
| 9481 | |
| 9482 | static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9483 | vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); |
| 9484 | p.param1 = ggml_get_op_params_f32(tensor: dst, i: 0); |
| 9485 | p.param2 = ggml_get_op_params_f32(tensor: dst, i: 1); |
| 9486 | |
| 9487 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SCALE, pc: std::move(p)); |
| 9488 | } |
| 9489 | |
| 9490 | static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9491 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SQR, pc: vk_op_unary_push_constants_init(src0, dst)); |
| 9492 | } |
| 9493 | |
| 9494 | static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9495 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SQRT, pc: vk_op_unary_push_constants_init(src0, dst)); |
| 9496 | } |
| 9497 | |
| 9498 | static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9499 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SIN, pc: vk_op_unary_push_constants_init(src0, dst)); |
| 9500 | } |
| 9501 | |
| 9502 | static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9503 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_COS, pc: vk_op_unary_push_constants_init(src0, dst)); |
| 9504 | } |
| 9505 | |
| 9506 | static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9507 | vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); |
| 9508 | p.param1 = ggml_get_op_params_f32(tensor: dst, i: 0); |
| 9509 | p.param2 = ggml_get_op_params_f32(tensor: dst, i: 1); |
| 9510 | |
| 9511 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CLAMP, pc: std::move(p)); |
| 9512 | } |
| 9513 | |
| 9514 | static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9515 | vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); |
| 9516 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_PAD, pc: std::move(p)); |
| 9517 | } |
| 9518 | |
| 9519 | static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9520 | const int32_t s0 = ggml_get_op_params_i32(tensor: dst, i: 0); |
| 9521 | const int32_t s1 = ggml_get_op_params_i32(tensor: dst, i: 1); |
| 9522 | const int32_t s2 = ggml_get_op_params_i32(tensor: dst, i: 2); |
| 9523 | const int32_t s3 = ggml_get_op_params_i32(tensor: dst, i: 3); |
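|  | // Bias each shift by 0x8000 and pack two shifts per uint32 so all four fit into
|  | // the two float params; the bits are copied with memcpy and unpacked in the shader.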
| 9524 | const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000); |
| 9525 | const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000); |
| 9526 | |
| 9527 | vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); |
| 9528 | memcpy(dest: &p.param1, src: &s01_packed, n: sizeof(float)); |
| 9529 | memcpy(dest: &p.param2, src: &s23_packed, n: sizeof(float)); |
| 9530 | |
| 9531 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ROLL, pc: std::move(p)); |
| 9532 | } |
| 9533 | |
| 9534 | static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9535 | vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne: ggml_nelements(tensor: dst)); |
| 9536 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_REPEAT, pc: std::move(p)); |
| 9537 | } |
| 9538 | |
| 9539 | static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9540 | vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne: ggml_nelements(tensor: dst)); |
| 9541 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_REPEAT_BACK, pc: std::move(p)); |
| 9542 | } |
| 9543 | |
| 9544 | static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9545 | uint32_t ne = (uint32_t)ggml_nelements(tensor: src0); |
| 9546 | if (ggml_is_quantized(type: src0->type) && ggml_is_quantized(type: dst->type)) { |
| 9547 | // Convert from number of logical elements to 2- or 4-byte units. |
| 9548 | ne /= ggml_blck_size(type: src0->type); |
| 9549 | if ((ggml_type_size(type: src0->type) % 4) == 0) { |
| 9550 | ne *= ggml_type_size(type: src0->type) / 4; |
| 9551 | } else { |
| 9552 | ne *= ggml_type_size(type: src0->type) / 2; |
| 9553 | } |
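|  | // Example: a Q8_0 -> Q8_0 copy of 4096 logical elements (block size 32,
|  | // 34 bytes per block) becomes 4096/32 * 34/2 = 2176 two-byte units.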
| 9554 | } |
| 9555 | |
| 9556 | vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne); |
| 9557 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CPY, pc: std::move(p)); |
| 9558 | } |
| 9559 | |
| 9560 | static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9561 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9562 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9563 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9564 | |
| 9565 | // Skip empty set_rows operations. For most ops the empty check at the start
| 9566 | // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst |
| 9567 | // with empty srcs. |
| 9568 | if (ggml_is_empty(tensor: src0) || ggml_is_empty(tensor: src1)) { |
| 9569 | return; |
| 9570 | } |
| 9571 | |
| 9572 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SET_ROWS, pc: { |
| 9573 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9574 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9575 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9576 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9577 | .misalign_offsets: 0, |
| 9578 | .param1: 0.0f, .param2: 0.0f, .param3: 0, |
| 9579 | }); |
| 9580 | } |
| 9581 | |
| 9582 | static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9583 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SILU_BACK, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: 0.0f, .param2: 0.0f }); |
| 9584 | } |
| 9585 | |
| 9586 | static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9587 | float * op_params = (float *)dst->op_params; |
| 9588 | |
| 9589 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_NORM, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: op_params[0], .param2: 0.0f }); |
| 9590 | } |
| 9591 | |
| 9592 | static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9593 | const int * int_op_params = (const int *)dst->op_params; |
| 9594 | const float * float_op_params = (const float *)dst->op_params; |
| 9595 | |
| 9596 | const uint32_t num_groups = int_op_params[0]; |
| 9597 | const float eps = float_op_params[1]; |
| 9598 | const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); |
| 9599 | |
| 9600 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_GROUP_NORM, pc: { .KX: group_size, .KY: 0, .param1: eps, .param2: 0.0f }); |
| 9601 | } |
| 9602 | |
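|  | // Helpers for the fused add + rms_norm path: the add shader produces one partial
|  | // sum per workgroup-sized chunk of the row (wg_denoms[0] elements), which the
|  | // rms_norm shader reduces later (see param3 in ggml_vk_rms_norm).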
| 9603 | static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { |
| 9604 | const uint32_t ne = (uint32_t)node->ne[0]; |
| 9605 | const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0]; |
| 9606 | const uint32_t num_partials = CEIL_DIV(ne, denom); |
| 9607 | return num_partials; |
| 9608 | } |
| 9609 | |
| 9610 | static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) { |
| 9611 | const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node); |
| 9612 | const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment); |
| 9613 | return num_bytes; |
| 9614 | } |
| 9615 | |
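|  | // Collects the RoPE parameters from dst->op_params (mode, n_dims, frequency/YaRN
|  | // settings, optional mrope sections) into the shared rope push constants.
|  | // set_rows_stride is nonzero only on the fused rope+view+set_rows path.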
| 9616 | static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) { |
| 9617 | const int n_dims = ((const int32_t *) dst->op_params)[1]; |
| 9618 | const int mode = ((const int32_t *) dst->op_params)[2]; |
| 9619 | // const int n_ctx = ((const int32_t *) dst->op_params)[3]; |
| 9620 | const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; |
| 9621 | const float freq_base = ((const float *) dst->op_params)[5]; |
| 9622 | const float freq_scale = ((const float *) dst->op_params)[6]; |
| 9623 | const float ext_factor = ((const float *) dst->op_params)[7]; |
| 9624 | const float attn_factor = ((const float *) dst->op_params)[8]; |
| 9625 | const float beta_fast = ((const float *) dst->op_params)[9]; |
| 9626 | const float beta_slow = ((const float *) dst->op_params)[10]; |
| 9627 | int sections[4] {}; |
| 9628 | if (mode & GGML_ROPE_TYPE_MROPE) { |
| 9629 | memcpy(dest: sections, src: (const int32_t *) dst->op_params + 11, n: sizeof(int)*4); |
| 9630 | } |
| 9631 | |
| 9632 | const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; |
| 9633 | |
| 9634 | float corr_dims[2]; |
| 9635 | ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, dims: corr_dims); |
| 9636 | |
| 9637 | const float theta_scale = powf(x: freq_base, y: -2.0f/n_dims); |
| 9638 | |
| 9639 | uint32_t nb01 = src0->nb[1] / ggml_type_size(type: src0->type); |
| 9640 | uint32_t nb02 = src0->nb[2] / ggml_type_size(type: src0->type); |
| 9641 | |
| 9642 | vk_op_rope_push_constants rope { |
| 9643 | .rope_mode: (uint32_t)mode, .ncols: (uint32_t)src0->ne[0], .n_dims: (uint32_t)n_dims, .freq_scale: freq_scale, .p_delta_rows: (uint32_t)src0->ne[1], |
| 9644 | .freq_base: freq_base, .ext_factor: ext_factor, .attn_factor: attn_factor, .corr_dims: {corr_dims[0], corr_dims[1]}, .theta_scale: theta_scale, |
| 9645 | .has_ff: has_ff, .ne02: (uint32_t)src0->ne[2], .s1: nb01, .s2: nb02, |
| 9646 | .sections: { sections[0], sections[1], sections[2], sections[3] }, .is_imrope: is_imrope, .is_back: backprop, .set_rows_stride: set_rows_stride, |
| 9647 | }; |
| 9648 | |
| 9649 | return rope; |
| 9650 | } |
| 9651 | |
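|  | // RMS norm. With fused ops this dispatches rms_norm+mul as a single binary op,
|  | // and rms_norm+mul+rope (optionally +view+set_rows) as one fused pipeline.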
| 9652 | static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) { |
| 9653 | ggml_tensor * dst; |
| 9654 | const ggml_tensor * src0; |
| 9655 | const ggml_tensor * src1; |
| 9656 | |
| 9657 | if (ctx->num_additional_fused_ops > 0) { |
| 9658 | // fused rms_norm + mul |
| 9659 | ggml_tensor *mul = cgraph->nodes[node_idx + 1]; |
| 9660 | ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0]; |
| 9661 | dst = mul; |
| 9662 | src0 = cgraph->nodes[node_idx]->src[0]; |
| 9663 | src1 = other_src; |
| 9664 | } else { |
| 9665 | dst = cgraph->nodes[node_idx]; |
| 9666 | src0 = src1 = dst->src[0]; |
| 9667 | } |
| 9668 | |
| 9669 | const uint32_t src0_type_size = ggml_type_size(type: src0->type); |
| 9670 | const uint32_t src1_type_size = ggml_type_size(type: src1->type); |
| 9671 | const uint32_t dst_type_size = ggml_type_size(type: dst->type); |
| 9672 | |
| 9673 | uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, node: dst) : 0; |
| 9674 | |
| 9675 | vk_op_binary_push_constants bin { |
| 9676 | .ne: (uint32_t)ggml_nelements(tensor: src0), |
| 9677 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size, |
| 9678 | .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size, |
| 9679 | .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size, |
| 9680 | .misalign_offsets: 0, |
| 9681 | .param1: op_params[0], .param2: 0.0f, .param3: (int32_t)param3, |
| 9682 | }; |
| 9683 | |
| 9684 | // more than one fused op means rms_norm+mul+rope (plus view+set_rows when num_additional_fused_ops == 4)
| 9685 | if (ctx->num_additional_fused_ops > 1) { |
| 9686 | static constexpr uint32_t max_tensors = 7; |
| 9687 | const ggml_tensor *tensors[max_tensors] {}; |
| 9688 | |
| 9689 | ggml_tensor *rms = cgraph->nodes[node_idx + 0]; |
| 9690 | ggml_tensor *mul = cgraph->nodes[node_idx + 1]; |
| 9691 | ggml_tensor *rope = cgraph->nodes[node_idx + 2]; |
| 9692 | |
| 9693 | ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; |
| 9694 | |
| 9695 | bool do_set_rows = ctx->num_additional_fused_ops == 4; |
| 9696 | |
| 9697 | tensors[0] = rms->src[0]; |
| 9698 | tensors[1] = other_src; |
| 9699 | tensors[2] = mul; |
| 9700 | tensors[3] = rope->src[1]; // pos |
| 9701 | tensors[4] = rope->src[2]; // ff |
| 9702 | tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst |
| 9703 | tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr; |
| 9704 | const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(type: tensors[5]->type) : 0; |
| 9705 | |
| 9706 | vk_op_rms_norm_mul_rope_push_constants pc; |
| 9707 | pc.bin = bin; |
| 9708 | pc.rope = ggml_vk_make_rope_constants(dst: rope, src0: rope->src[0], has_ff: tensors[4] != nullptr, backprop: false, set_rows_stride); |
| 9709 | |
| 9710 | vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32; |
| 9711 | |
| 9712 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1); |
| 9713 | |
| 9714 | ggml_backend_vk_buffer_context * buf_ctx[max_tensors]; |
| 9715 | vk_buffer buf[max_tensors]; |
| 9716 | size_t offset[max_tensors]; |
| 9717 | bool uma[max_tensors]; |
| 9718 | |
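|  | // Resolve a device buffer and offset for every bound tensor, preferring a
|  | // host-visible UMA mapping when the device exposes one.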
| 9719 | for (uint32_t i = 0; i < max_tensors; ++i) { |
| 9720 | if (!tensors[i]) { |
| 9721 | // If any remaining descriptors are unused, just point them at src[0] |
| 9722 | buf[i] = buf[0]; |
| 9723 | offset[i] = 0; |
| 9724 | continue; |
| 9725 | } |
| 9726 | buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context; |
| 9727 | buf[i] = nullptr; |
| 9728 | offset[i] = 0; |
| 9729 | uma[i] = false; |
| 9730 | |
| 9731 | if (ctx->device->uma) { |
| 9732 | ggml_vk_host_get(device: ctx->device, ptr: tensors[i]->data, buf&: buf[i], buf_offset&: offset[i]); |
| 9733 | uma[i] = buf[i] != nullptr; |
| 9734 | } |
| 9735 | if (!uma[i]) { |
| 9736 | buf[i] = buf_ctx[i]->dev_buffer; |
| 9737 | offset[i] = vk_tensor_offset(tensor: tensors[i]) + tensors[i]->view_offs; |
| 9738 | } |
| 9739 | GGML_ASSERT(buf[i] != nullptr); |
| 9740 | } |
| 9741 | |
| 9742 | std::array<uint32_t, 3> elements; |
| 9743 | elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; |
| 9744 | |
| 9745 | static_assert(max_tensors == 7); |
| 9746 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, |
| 9747 | descriptor_buffer_infos: { |
| 9748 | ggml_vk_subbuffer(ctx, buf: buf[0], offset: offset[0]), |
| 9749 | ggml_vk_subbuffer(ctx, buf: buf[1], offset: offset[1]), |
| 9750 | ggml_vk_subbuffer(ctx, buf: buf[2], offset: offset[2]), |
| 9751 | ggml_vk_subbuffer(ctx, buf: buf[3], offset: offset[3]), |
| 9752 | ggml_vk_subbuffer(ctx, buf: buf[4], offset: offset[4]), |
| 9753 | ggml_vk_subbuffer(ctx, buf: buf[5], offset: offset[5]), |
| 9754 | ggml_vk_subbuffer(ctx, buf: buf[6], offset: offset[6]), |
| 9755 | }, push_constants: pc, elements); |
| 9756 | } else { |
| 9757 | ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_RMS_NORM, pc: std::move(bin)); |
| 9758 | } |
| 9759 | |
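|  | // Advance the partials scratch offset past the region used by this node so the
|  | // next fused add+rms_norm pair gets its own chunk, then clear the flags.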
| 9760 | if (ctx->do_add_rms_partials_offset_calculation) { |
| 9761 | ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, node: src0); |
| 9762 | ctx->do_add_rms_partials = false; |
| 9763 | ctx->do_add_rms_partials_offset_calculation = false; |
| 9764 | } |
| 9765 | } |
| 9766 | |
| 9767 | static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9768 | float * op_params = (float *)dst->op_params; |
| 9769 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_RMS_NORM_BACK, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: op_params[0], .param2: 0.0f }); |
| 9770 | } |
| 9771 | |
| 9772 | static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9773 | float * op_params = (float *)dst->op_params; |
| 9774 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_L2_NORM, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: op_params[0], .param2: 0.0f }); |
| 9775 | } |
| 9776 | |
| 9777 | static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9778 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_UNARY, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: 0.0f, .param2: 0.0f }); |
| 9779 | } |
| 9780 | |
| 9781 | static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9782 | const float * op_params_f = (const float *)dst->op_params; |
| 9783 | |
| 9784 | const bool swapped = (bool)dst->op_params[1]; |
| 9785 | const bool split = src1 != nullptr; |
| 9786 | const float alpha = op_params_f[2]; |
| 9787 | const float limit = op_params_f[3]; |
| 9788 | |
| 9789 | GGML_ASSERT(ggml_is_contiguous(src0)); |
| 9790 | |
| 9791 | if (!split) { |
| 9792 | GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); |
| 9793 | } else { |
| 9794 | GGML_ASSERT(src0->ne[0] == src1->ne[0]); |
| 9795 | GGML_ASSERT(src0->ne[0] == dst->ne[0]); |
| 9796 | GGML_ASSERT(src0->type == src1->type); |
| 9797 | } |
| 9798 | |
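|  | // mode encodes the input layout for the shader: 0 = gate and value packed in
|  | // src0, 1 = the same but with the halves swapped, 2 = split across src0/src1.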
| 9799 | const uint32_t mode = split ? 2 : (swapped ? 1 : 0); |
| 9800 | |
| 9801 | ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_GLU, |
| 9802 | pc: { |
| 9803 | .N: (uint32_t)ggml_nelements(tensor: dst), |
| 9804 | .ne00: (uint32_t)src0->ne[0], |
| 9805 | .ne20: (uint32_t)dst->ne[0], |
| 9806 | .mode: mode, |
| 9807 | .alpha: alpha, |
| 9808 | .limit: limit |
| 9809 | }); |
| 9810 | } |
| 9811 | |
| 9812 | static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9813 | int32_t * op_params = (int32_t *)dst->op_params; |
| 9814 | ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_DIAG_MASK_INF, pc: { .ncols: (uint32_t)src0->ne[0], .rows_per_channel: (uint32_t)src0->ne[1], .n_past: op_params[0] }); |
| 9815 | } |
| 9816 | |
| 9817 | static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { |
| 9818 | float * op_params = (float *)dst->op_params; |
| 9819 | |
| 9820 | float scale = op_params[0]; |
| 9821 | float max_bias = op_params[1]; |
| 9822 | |
| 9823 | const uint32_t ncols = (uint32_t)src0->ne[0]; |
| 9824 | const uint32_t nrows_x = (uint32_t)ggml_nrows(tensor: src0); |
| 9825 | const uint32_t nrows_y = (uint32_t)src0->ne[1]; |
| 9826 | |
| 9827 | const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u; |
| 9828 | const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u; |
| 9829 | const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u; |
| 9830 | const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u; |
| 9831 | const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u; |
| 9832 | |
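|  | // ALiBi slope inputs, following the usual ggml scheme: head h gets slope
|  | // m0^(h+1) for h < n_head_log2 and m1^(2*(h-n_head_log2)+1) otherwise;
|  | // max_bias == 0 effectively disables the bias.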
| 9833 | const uint32_t n_head_kv = src0->ne[2]; |
| 9834 | const uint32_t n_head_log2 = 1u << (uint32_t) floorf(x: log2f(x: (float) n_head_kv)); |
| 9835 | |
| 9836 | const float m0 = powf(x: 2.0f, y: -(max_bias ) / n_head_log2); |
| 9837 | const float m1 = powf(x: 2.0f, y: -(max_bias / 2.0f) / n_head_log2); |
| 9838 | |
| 9839 | ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, src3: nullptr, dst, op: GGML_OP_SOFT_MAX, pc: { |
| 9840 | .KX: ncols, |
| 9841 | .KY: src1 != nullptr ? nrows_y : (uint32_t)0, |
| 9842 | .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2], |
| 9843 | .ne12: ne12, .ne13: ne13, |
| 9844 | .nb11: nb11, .nb12: nb12, .nb13: nb13, |
| 9845 | .scale: scale, .max_bias: max_bias, |
| 9846 | .m0: m0, .m1: m1, |
| 9847 | .n_head_log2: n_head_log2, |
| 9848 | .nrows_x: nrows_x, |
| 9849 | .has_sinks: src2 != nullptr |
| 9850 | }); |
| 9851 | } |
| 9852 | |
| 9853 | static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9854 | float * op_params = (float *)dst->op_params; |
| 9855 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SOFT_MAX_BACK, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)ggml_nrows(tensor: src0), .param1: op_params[0], .param2: op_params[1] }); |
| 9856 | } |
| 9857 | |
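|  | // Fused top-k MoE routing: the routing weights and expert ids sit at
|  | // mode-dependent offsets in the fused node sequence, and a single dispatch of
|  | // the soft_max-derived pipeline writes both outputs.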
| 9858 | static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { |
| 9859 | topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(num: ctx->num_additional_fused_ops); |
| 9860 | ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; |
| 9861 | ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : |
| 9862 | (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : |
| 9863 | cgraph->nodes[node_idx + 5]; |
| 9864 | ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; |
| 9865 | |
| 9866 | GGML_ASSERT(logits->type == GGML_TYPE_F32); |
| 9867 | GGML_ASSERT(weights->type == GGML_TYPE_F32); |
| 9868 | GGML_ASSERT(ids->type == GGML_TYPE_I32); |
| 9869 | |
| 9870 | const int n_experts = logits->ne[0]; |
| 9871 | const int n_rows = logits->ne[1]; |
| 9872 | const int n_expert_used = weights->ne[1]; |
| 9873 | |
| 9874 | GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); |
| 9875 | |
| 9876 | vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: nullptr, src1: nullptr, src2: nullptr, dst: cgraph->nodes[node_idx], op: GGML_OP_SOFT_MAX); |
| 9877 | |
| 9878 | ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1); |
| 9879 | |
| 9880 | vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, tensor: logits); |
| 9881 | vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, tensor: weights); |
| 9882 | vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, tensor: ids); |
| 9883 | |
| 9884 | vk_op_topk_moe_push_constants pc {}; |
| 9885 | pc.n_rows = n_rows; |
| 9886 | pc.n_expert_used = n_expert_used; |
| 9887 | if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { |
| 9888 | ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; |
| 9889 | pc.clamp_min = ggml_get_op_params_f32(tensor: clamp, i: 0); |
| 9890 | pc.clamp_max = ggml_get_op_params_f32(tensor: clamp, i: 1); |
| 9891 | } |
| 9892 | |
| 9893 | GGML_ASSERT(n_expert_used <= n_experts); |
| 9894 | |
| 9895 | const uint32_t rows_per_block = 4; |
| 9896 | std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; |
| 9897 | |
| 9898 | ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, descriptor_buffer_infos: {logits_buf, weights_buf, ids_buf}, push_constants: pc, elements); |
| 9899 | } |
| 9900 | |
| 9901 | static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) { |
| 9902 | ggml_tensor * dst = cgraph->nodes[node_idx]; |
| 9903 | const ggml_tensor * src0 = dst->src[0]; |
| 9904 | const ggml_tensor * src1 = dst->src[1]; |
| 9905 | const ggml_tensor * src2 = dst->src[2]; |
| 9906 | const ggml_tensor * src3 = nullptr; |
| 9907 | const int n_dims = ((int32_t *) dst->op_params)[1]; |
| 9908 | const int mode = ((int32_t *) dst->op_params)[2]; |
| 9909 | // const int n_ctx = ((int32_t *) dst->op_params)[3]; |
| 9910 | const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; |
| 9911 | const float freq_base = ((float *) dst->op_params)[5]; |
| 9912 | const float beta_fast = ((float *) dst->op_params)[9]; |
| 9913 | const float beta_slow = ((float *) dst->op_params)[10]; |
| 9914 | int sections[4] {}; |
| 9915 | if (mode & GGML_ROPE_TYPE_MROPE) { |
| 9916 | memcpy(dest: sections, src: (int32_t *) dst->op_params + 11, n: sizeof(int)*4); |
| 9917 | } |
| 9918 | |
| 9919 | float corr_dims[2]; |
| 9920 | ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, dims: corr_dims); |
| 9921 | |
| 9922 | uint32_t set_rows_stride = 0; |
| 9923 | // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride,
| 9924 | // overrides dst with the set_rows node, and sets src3 to the row indices.
| 9925 | if (ctx->num_additional_fused_ops > 0) { |
| 9926 | set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(type: cgraph->nodes[node_idx + 2]->type); |
| 9927 | src3 = cgraph->nodes[node_idx + 2]->src[1]; |
| 9928 | dst = cgraph->nodes[node_idx + 2]; |
| 9929 | } |
| 9930 | |
| 9931 | ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, op: GGML_OP_ROPE, |
| 9932 | pc: ggml_vk_make_rope_constants(dst: cgraph->nodes[node_idx], src0, has_ff: src2 != nullptr, backprop, set_rows_stride)); |
| 9933 | } |
| 9934 | |
| 9935 | static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9936 | int32_t * op_params = (int32_t *)dst->op_params; |
| 9937 | |
| 9938 | uint32_t ncols = src0->ne[0]; |
| 9939 | uint32_t nrows = ggml_nrows(tensor: src0); |
| 9940 | |
| 9941 | ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ARGSORT, pc: { |
| 9942 | .ncols: ncols, |
| 9943 | .nrows: nrows, |
| 9944 | .order: op_params[0], |
| 9945 | }); |
| 9946 | } |
| 9947 | |
| 9948 | static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9949 | vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src: src0, dst, n_cols: ggml_nelements(tensor: src0)); |
| 9950 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SUM, pc&: p); |
| 9951 | } |
| 9952 | |
| 9953 | static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9954 | vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src: src0, dst, n_cols: src0->ne[0]); |
| 9955 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SUM_ROWS, pc&: p); |
| 9956 | } |
| 9957 | |
| 9958 | static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9959 | vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src: src0, dst, n_cols: src0->ne[0]); |
| 9960 | p.weight = 1.0f / (float)src0->ne[0]; |
| 9961 | ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_MEAN, pc&: p); |
| 9962 | } |
| 9963 | |
| 9964 | static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 9965 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ARGMAX, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: 0.0f, .param2: 0.0f }); |
| 9966 | } |
| 9967 | |
| 9968 | static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9969 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_COUNT_EQUAL, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: 0.0f, .param2: 0.0f }); |
| 9970 | } |
| 9971 | |
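|  | // im2col: unfolds convolution patches of src1 into rows so the convolution can
|  | // be computed as a matmul; the destination is addressed through its buffer
|  | // device address (dst_addr) carried in the push constants.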
| 9972 | static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 9973 | const int32_t s0 = dst->op_params[0]; |
| 9974 | const int32_t s1 = dst->op_params[1]; |
| 9975 | const int32_t p0 = dst->op_params[2]; |
| 9976 | const int32_t p1 = dst->op_params[3]; |
| 9977 | const int32_t d0 = dst->op_params[4]; |
| 9978 | const int32_t d1 = dst->op_params[5]; |
| 9979 | |
| 9980 | const bool is_2D = dst->op_params[6] == 1; |
| 9981 | |
| 9982 | const uint32_t IC = src1->ne[is_2D ? 2 : 1]; |
| 9983 | const uint32_t IH = is_2D ? src1->ne[1] : 1; |
| 9984 | const uint32_t IW = src1->ne[0]; |
| 9985 | |
| 9986 | const uint32_t KH = is_2D ? src0->ne[1] : 1; |
| 9987 | const uint32_t KW = src0->ne[0]; |
| 9988 | |
| 9989 | const uint32_t OH = is_2D ? dst->ne[2] : 1; |
| 9990 | const uint32_t OW = dst->ne[1]; |
| 9991 | |
| 9992 | const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 |
| 9993 | const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 |
| 9994 | |
| 9995 | const uint32_t pelements = OW * KW * KH; |
| 9996 | |
| 9997 | const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; |
| 9998 | const vk_buffer d_buf = d_buf_ctx->dev_buffer; |
| 9999 | |
| 10000 | const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(tensor: dst) + dst->view_offs; |
| 10001 | |
| 10002 | ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_IM2COL, pc: { |
| 10003 | .dst_addr: dst_addr, |
| 10004 | .batch_offset: batch_offset, .offset_delta: offset_delta, |
| 10005 | .IC: IC, .IW: IW, .IH: IH, .OW: OW, .OH: OH, .KW: KW, .KH: KH, |
| 10006 | .pelements: pelements, |
| 10007 | .CHW: IC * KH * KW, |
| 10008 | .s0: s0, .s1: s1, .p0: p0, .p1: p1, .d0: d0, .d1: d1, |
| 10009 | }); |
| 10010 | } |
| 10011 | |
| 10012 | static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 10013 | GGML_TENSOR_BINARY_OP_LOCALS |
| 10014 | |
| 10015 | const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; |
| 10016 | const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; |
| 10017 | const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; |
| 10018 | const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; |
| 10019 | const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; |
| 10020 | const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; |
| 10021 | const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; |
| 10022 | const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; |
| 10023 | const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; |
| 10024 | const int32_t IC = ((const int32_t *)(dst->op_params))[9]; |
| 10025 | |
| 10026 | const int64_t N = ne13 / IC; |
| 10027 | const int64_t ID = ne12; |
| 10028 | const int64_t IH = ne11; |
| 10029 | const int64_t IW = ne10; |
| 10030 | |
| 10031 | const int64_t KD = ne02; |
| 10032 | const int64_t KH = ne01; |
| 10033 | const int64_t KW = ne00; |
| 10034 | |
| 10035 | const int64_t OD = ne3 / N; |
| 10036 | const int64_t OH = ne2; |
| 10037 | const int64_t OW = ne1; |
| 10038 | |
| 10039 | const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; |
| 10040 | const vk_buffer d_buf = d_buf_ctx->dev_buffer; |
| 10041 | |
| 10042 | const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(tensor: dst) + dst->view_offs; |
| 10043 | |
| 10044 | vk_op_im2col_3d_push_constants pc {}; |
| 10045 | |
| 10046 | pc.dst_addr = dst_addr; |
| 10047 | pc.nb10 = nb10 / ggml_type_size(type: src1->type); |
| 10048 | pc.nb11 = nb11 / ggml_type_size(type: src1->type); |
| 10049 | pc.nb12 = nb12 / ggml_type_size(type: src1->type); |
| 10050 | pc.nb13 = nb13 / ggml_type_size(type: src1->type); |
| 10051 | pc.s0 = s0; |
| 10052 | pc.s1 = s1; |
| 10053 | pc.s2 = s2; |
| 10054 | pc.p0 = p0; |
| 10055 | pc.p1 = p1; |
| 10056 | pc.p2 = p2; |
| 10057 | pc.d0 = d0; |
| 10058 | pc.d1 = d1; |
| 10059 | pc.d2 = d2; |
| 10060 | pc.IW = IW; |
| 10061 | pc.IH = IH; |
| 10062 | pc.ID = ID; |
| 10063 | pc.IC = IC; |
| 10064 | pc.KW = KW; |
| 10065 | pc.OH = OH; |
| 10066 | pc.KD_KH_KW = KD*KH*KW; |
| 10067 | pc.KH_KW = KH*KW; |
| 10068 | pc.IC_KD_KH_KW = IC*KD*KH*KW; |
| 10069 | pc.N_OD_OH = N*OD*OH; |
| 10070 | pc.OD_OH = OD*OH; |
| 10071 | pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; |
| 10072 | pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; |
| 10073 | pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; |
| 10074 | |
| 10075 | ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_IM2COL_3D, pc: std::move(pc)); |
| 10076 | } |
| 10077 | |
| 10078 | static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 10079 | const uint32_t dim = dst->op_params[0]; |
| 10080 | const uint32_t max_period = dst->op_params[1]; |
| 10081 | const uint32_t nb1 = dst->nb[1] / ggml_type_size(type: dst->type); |
| 10082 | |
| 10083 | ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_TIMESTEP_EMBEDDING, pc: { |
| 10084 | .nb1: nb1, .dim: dim, .max_period: max_period, |
| 10085 | }); |
| 10086 | } |
| 10087 | |
| 10088 | static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 10089 | // src0: (K, Cout, Cin, 1) -- kernel |
| 10090 | // src1: (L, Cin, 1, 1) -- input |
| 10091 | // dst: (*, Cout, 1, 1) |
| 10092 | |
| 10093 | GGML_ASSERT(src0->type == GGML_TYPE_F32); |
| 10094 | GGML_ASSERT(src1->type == GGML_TYPE_F32); |
| 10095 | GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| 10096 | |
| 10097 | GGML_TENSOR_BINARY_OP_LOCALS |
| 10098 | |
| 10099 | GGML_ASSERT(nb00 == sizeof(float)); |
| 10100 | GGML_ASSERT(nb10 == sizeof(float)); |
| 10101 | |
| 10102 | const int32_t s0 = dst->op_params[0]; |
| 10103 | |
| 10104 | vk_op_conv_transpose_1d_push_constants p{}; |
| 10105 | p.Cout = static_cast<uint32_t>(ne01); |
| 10106 | p.Cin = static_cast<uint32_t>(ne02); |
| 10107 | p.K = static_cast<uint32_t>(ne00); |
| 10108 | p.L = static_cast<uint32_t>(ne10); |
| 10109 | p.KL = static_cast<uint32_t>(ne0); |
| 10110 | p.nb01 = static_cast<uint32_t>(nb01 / nb00); |
| 10111 | p.nb02 = static_cast<uint32_t>(nb02 / nb00); |
| 10112 | p.nb11 = static_cast<uint32_t>(nb11 / nb10); |
| 10113 | p.nb1 = static_cast<uint32_t>(nb1 / nb0); |
| 10114 | p.s0 = static_cast<uint32_t>(s0); |
| 10115 | |
| 10116 | ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_TRANSPOSE_1D, pc: std::move(p)); |
| 10117 | } |
| 10118 | |
| 10119 | static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 10120 | uint32_t op = static_cast<uint32_t>(dst->op_params[0]); |
| 10121 | const int32_t k1 = dst->op_params[1]; |
| 10122 | const int32_t k0 = dst->op_params[2]; |
| 10123 | const int32_t s1 = dst->op_params[3]; |
| 10124 | const int32_t s0 = dst->op_params[4]; |
| 10125 | const int32_t p1 = dst->op_params[5]; |
| 10126 | const int32_t p0 = dst->op_params[6]; |
| 10127 | |
| 10128 | const uint32_t IH = src0->ne[1]; |
| 10129 | const uint32_t IW = src0->ne[0]; |
| 10130 | |
| 10131 | const uint32_t N = dst->ne[3]; |
| 10132 | |
| 10133 | const uint32_t OC = dst->ne[2]; |
| 10134 | const uint32_t OH = dst->ne[1]; |
| 10135 | const uint32_t OW = dst->ne[0]; |
| 10136 | |
| 10137 | const uint32_t parallel_elements = N * OC * OH * OW; |
| 10138 | |
| 10139 | ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_POOL_2D, pc: { |
| 10140 | .IW: IW, .IH: IH, .OW: OW, .OH: OH, .OC: OC, |
| 10141 | .pelements: parallel_elements, |
| 10142 | .op: op, |
| 10143 | .k0: k0, .k1: k1, .s0: s0, .s1: s1, .p0: p0, .p1: p1, |
| 10144 | }); |
| 10145 | } |
| 10146 | |
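|  | // Direct 2D convolution: src0 is the kernel (Cout = ne03, Cin = ne02), src1 the
|  | // input (N = ne13); strides, padding and dilation come from dst->op_params and
|  | // all byte strides are converted to element strides for the shader.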
| 10147 | static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, |
| 10148 | const ggml_tensor * src1, ggml_tensor * dst) { |
| 10149 | GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); |
| 10150 | GGML_ASSERT(src1->type == GGML_TYPE_F32); |
| 10151 | GGML_ASSERT(dst->type == GGML_TYPE_F32); |
| 10152 | |
| 10153 | GGML_TENSOR_BINARY_OP_LOCALS |
| 10154 | |
| 10155 | GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); |
| 10156 | GGML_ASSERT(nb10 == sizeof(float)); |
| 10157 | GGML_ASSERT(nb0 == sizeof(float)); |
| 10158 | |
| 10159 | vk_op_conv2d_push_constants p{}; |
| 10160 | p.Cout = static_cast<uint32_t>(ne03); |
| 10161 | p.Cin = static_cast<uint32_t>(ne02); |
| 10162 | p.N = static_cast<uint32_t>(ne13); |
| 10163 | |
| 10164 | p.KW = static_cast<uint32_t>(ne00); |
| 10165 | p.KH = static_cast<uint32_t>(ne01); |
| 10166 | p.W = static_cast<uint32_t>(ne10); |
| 10167 | p.H = static_cast<uint32_t>(ne11); |
| 10168 | p.OW = static_cast<uint32_t>(ne0); |
| 10169 | p.OH = static_cast<uint32_t>(ne1); |
| 10170 | |
| 10171 | p.s0 = static_cast<uint32_t>(dst->op_params[0]); |
| 10172 | p.s1 = static_cast<uint32_t>(dst->op_params[1]); |
| 10173 | p.p0 = static_cast<uint32_t>(dst->op_params[2]); |
| 10174 | p.p1 = static_cast<uint32_t>(dst->op_params[3]); |
| 10175 | p.d0 = static_cast<uint32_t>(dst->op_params[4]); |
| 10176 | p.d1 = static_cast<uint32_t>(dst->op_params[5]); |
| 10177 | |
| 10178 | p.nb01 = static_cast<uint32_t>(nb01 / nb00); |
| 10179 | p.nb02 = static_cast<uint32_t>(nb02 / nb00); |
| 10180 | p.nb03 = static_cast<uint32_t>(nb03 / nb00); |
| 10181 | |
| 10182 | p.nb11 = static_cast<uint32_t>(nb11 / nb10); |
| 10183 | p.nb12 = static_cast<uint32_t>(nb12 / nb10); |
| 10184 | p.nb13 = static_cast<uint32_t>(nb13 / nb10); |
| 10185 | |
| 10186 | p.nb1 = static_cast<uint32_t>(nb1 / nb0); |
| 10187 | p.nb2 = static_cast<uint32_t>(nb2 / nb0); |
| 10188 | p.nb3 = static_cast<uint32_t>(nb3 / nb0); |
| 10189 | |
| 10190 | GGML_ASSERT(ne03 == ne2); |
| 10191 | GGML_ASSERT(ne02 == ne12); |
| 10192 | |
| 10193 | ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_2D, pc: std::move(p)); |
| 10194 | } |
| 10195 | |
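|  | // Transposed 2D convolution: analogous push constants to conv_2d, but with zero
|  | // padding, unit dilation and a single stride (op_params[0]) for both axes;
|  | // Cout/Cin come from ne02/ne03, i.e. swapped relative to GGML_OP_CONV_2D.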
| 10196 | static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, |
| 10197 | const ggml_tensor * src1, ggml_tensor * dst) { |
| 10198 | GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); |
| 10199 | GGML_ASSERT(src1->type == GGML_TYPE_F32); |
| 10200 | GGML_ASSERT(dst->type == GGML_TYPE_F32); |
| 10201 | |
| 10202 | GGML_TENSOR_BINARY_OP_LOCALS |
| 10203 | |
| 10204 | GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); |
| 10205 | GGML_ASSERT(nb10 == sizeof(float)); |
| 10206 | GGML_ASSERT(nb0 == sizeof(float)); |
| 10207 | |
| 10208 | vk_op_conv_transpose_2d_push_constants p{}; |
| 10209 | p.Cout = static_cast<uint32_t>(ne02); |
| 10210 | p.Cin = static_cast<uint32_t>(ne03); |
| 10211 | p.N = static_cast<uint32_t>(ne13); |
| 10212 | |
| 10213 | p.KW = static_cast<uint32_t>(ne00); |
| 10214 | p.KH = static_cast<uint32_t>(ne01); |
| 10215 | p.W = static_cast<uint32_t>(ne10); |
| 10216 | p.H = static_cast<uint32_t>(ne11); |
| 10217 | p.OW = static_cast<uint32_t>(ne0); |
| 10218 | p.OH = static_cast<uint32_t>(ne1); |
| 10219 | |
| 10220 | p.s0 = static_cast<uint32_t>(dst->op_params[0]); |
| 10221 | p.s1 = static_cast<uint32_t>(dst->op_params[0]); |
| 10222 | p.p0 = 0; |
| 10223 | p.p1 = 0; |
| 10224 | p.d0 = 1; |
| 10225 | p.d1 = 1; |
| 10226 | |
| 10227 | p.nb01 = static_cast<uint32_t>(nb01 / nb00); |
| 10228 | p.nb02 = static_cast<uint32_t>(nb02 / nb00); |
| 10229 | p.nb03 = static_cast<uint32_t>(nb03 / nb00); |
| 10230 | |
| 10231 | p.nb11 = static_cast<uint32_t>(nb11 / nb10); |
| 10232 | p.nb12 = static_cast<uint32_t>(nb12 / nb10); |
| 10233 | p.nb13 = static_cast<uint32_t>(nb13 / nb10); |
| 10234 | |
| 10235 | p.nb1 = static_cast<uint32_t>(nb1 / nb0); |
| 10236 | p.nb2 = static_cast<uint32_t>(nb2 / nb0); |
| 10237 | p.nb3 = static_cast<uint32_t>(nb3 / nb0); |
| 10238 | |
| 10239 | GGML_ASSERT(ne02 == ne2); |
| 10240 | GGML_ASSERT(ne03 == ne12); |
| 10241 | |
| 10242 | ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_TRANSPOSE_2D, pc: std::move(p)); |
| 10243 | } |
| 10244 | |
| 10245 | static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| 10246 | vk_op_conv2d_dw_push_constants p{}; |
| 10247 | p.ne = ggml_nelements(tensor: dst); |
| 10248 | p.channels = dst->ne[2]; |
| 10249 | p.batches = dst->ne[3]; |
| 10250 | p.dst_w = dst->ne[0]; |
| 10251 | p.dst_h = dst->ne[1]; |
| 10252 | p.src_w = src1->ne[0]; |
| 10253 | p.src_h = src1->ne[1]; |
| 10254 | p.knl_w = src0->ne[0]; |
| 10255 | p.knl_h = src0->ne[1]; |
| 10256 | p.stride_x = dst->op_params[0]; |
| 10257 | p.stride_y = dst->op_params[1]; |
| 10258 | p.pad_x = dst->op_params[2]; |
| 10259 | p.pad_y = dst->op_params[3]; |
| 10260 | p.dilation_x = dst->op_params[4]; |
| 10261 | p.dilation_y = dst->op_params[5]; |
| 10262 | |
| 10263 | GGML_ASSERT(src0->ne[3] == p.channels); |
| 10264 | GGML_ASSERT(src1->ne[3] == p.batches); |
| 10265 | |
| 10266 | ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_2D_DW, pc: std::move(p)); |
| 10267 | } |
| 10268 | |
| 10269 | static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { |
| 10270 | const float * op_params = (const float *)dst->op_params; |
| 10271 | ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_LEAKY_RELU, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: op_params[0], .param2: 0.0f }); |
| 10272 | } |
| 10273 | |
| 10274 | #ifdef GGML_VULKAN_RUN_TESTS |
| 10275 | static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { |
| 10276 | if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { |
| 10277 | return; |
| 10278 | } |
| 10279 | i0 = std::max(i0, 5); |
| 10280 | i1 = std::max(i1, 5); |
| 10281 | i2 = std::max(i2, 0); |
| 10282 | fprintf(stderr, " " ); |
| 10283 | for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { |
| 10284 | fprintf(stderr, "%7d " , idx1); |
| 10285 | } |
| 10286 | fprintf(stderr, "\n" ); |
| 10287 | for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { |
| 10288 | fprintf(stderr, "%7d: " , idx0); |
| 10289 | for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { |
| 10290 | if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { |
| 10291 | float val; |
| 10292 | if (type == GGML_TYPE_F32) { |
| 10293 | val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); |
| 10294 | } else if (type == GGML_TYPE_F16) { |
| 10295 | val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); |
| 10296 | } else { |
| 10297 | GGML_ABORT("fatal error" ); |
| 10298 | } |
| 10299 | fprintf(stderr, "% 7.2f " , val); |
| 10300 | } else { |
| 10301 | fprintf(stderr, " " ); |
| 10302 | } |
| 10303 | } |
| 10304 | fprintf(stderr, "\n" ); |
| 10305 | } |
| 10306 | } |
| 10307 | |
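|  | // Test-only helper (GGML_VULKAN_RUN_TESTS): runs num_it matmuls of the selected
|  | // shader variant on the GPU, then recomputes the product on the CPU through a
|  | // small ggml graph to report timing and average error.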
| 10308 | template <typename X_TYPE, typename Y_TYPE> |
| 10309 | static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { |
| 10310 | VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")" ); |
| 10311 | const size_t x_ne = m * k * batch; |
| 10312 | const size_t y_ne = k * n * batch; |
| 10313 | const size_t d_ne = m * n * batch; |
| 10314 | |
| 10315 | vk_pipeline p; |
| 10316 | std::string shname; |
| 10317 | if (shader_size == 0) { |
| 10318 | if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10319 | p = ctx->device->pipeline_matmul_f32->a_s; |
| 10320 | shname = "F32_ALIGNED_S" ; |
| 10321 | } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10322 | p = ctx->device->pipeline_matmul_f32_f16->a_s; |
| 10323 | shname = "F32_F16_ALIGNED_S" ; |
| 10324 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10325 | p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; |
| 10326 | shname = "F16_F32_ALIGNED_S" ; |
| 10327 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10328 | p = ctx->device->pipeline_matmul_f16.f32acc->a_s; |
| 10329 | shname = "F16_ALIGNED_S" ; |
| 10330 | } else { |
| 10331 | GGML_ABORT("fatal error" ); |
| 10332 | } |
| 10333 | } else if (shader_size == 1) { |
| 10334 | if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10335 | p = ctx->device->pipeline_matmul_f32->a_m; |
| 10336 | shname = "F32_ALIGNED_M" ; |
| 10337 | } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10338 | p = ctx->device->pipeline_matmul_f32_f16->a_m; |
| 10339 | shname = "F32_F16_ALIGNED_M" ; |
| 10340 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10341 | p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; |
| 10342 | shname = "F16_F32_ALIGNED_M" ; |
| 10343 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10344 | p = ctx->device->pipeline_matmul_f16.f32acc->a_m; |
| 10345 | shname = "F16_ALIGNED_M" ; |
| 10346 | } else { |
| 10347 | GGML_ABORT("fatal error" ); |
| 10348 | } |
| 10349 | } else if (shader_size == 2) { |
| 10350 | if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10351 | p = ctx->device->pipeline_matmul_f32->a_l; |
| 10352 | shname = "F32_ALIGNED_L" ; |
| 10353 | } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10354 | p = ctx->device->pipeline_matmul_f32_f16->a_l; |
| 10355 | shname = "F32_F16_ALIGNED_L" ; |
| 10356 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10357 | p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; |
| 10358 | shname = "F16_F32_ALIGNED_L" ; |
| 10359 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10360 | p = ctx->device->pipeline_matmul_f16.f32acc->a_l; |
| 10361 | shname = "F16_ALIGNED_L" ; |
| 10362 | } else { |
| 10363 | GGML_ABORT("fatal error" ); |
| 10364 | } |
| 10365 | } else { |
| 10366 | GGML_ASSERT(0); |
| 10367 | } |
| 10368 | |
| 10369 | const size_t kpad = ggml_vk_align_size(k, p->align); |
| 10370 | |
| 10371 | if (k != kpad) { |
| 10372 | if (shader_size == 0) { |
| 10373 | if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10374 | p = ctx->device->pipeline_matmul_f32->s; |
| 10375 | shname = "F32_S" ; |
| 10376 | } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10377 | p = ctx->device->pipeline_matmul_f32_f16->s; |
| 10378 | shname = "F32_F16_S" ; |
| 10379 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10380 | p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; |
| 10381 | shname = "F16_F32_S" ; |
| 10382 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10383 | p = ctx->device->pipeline_matmul_f16.f32acc->s; |
| 10384 | shname = "F16_S" ; |
| 10385 | } |
| 10386 | } else if (shader_size == 1) { |
| 10387 | if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10388 | p = ctx->device->pipeline_matmul_f32->m; |
| 10389 | shname = "F32_M" ; |
| 10390 | } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10391 | p = ctx->device->pipeline_matmul_f32_f16->m; |
| 10392 | shname = "F32_F16_M" ; |
| 10393 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10394 | p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; |
| 10395 | shname = "F16_F32_M" ; |
| 10396 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10397 | p = ctx->device->pipeline_matmul_f16.f32acc->m; |
| 10398 | shname = "F16_M" ; |
| 10399 | } |
| 10400 | } else if (shader_size == 2) { |
| 10401 | if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10402 | p = ctx->device->pipeline_matmul_f32->l; |
| 10403 | shname = "F32_L" ; |
| 10404 | } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10405 | p = ctx->device->pipeline_matmul_f32_f16->l; |
| 10406 | shname = "F32_F16_L" ; |
| 10407 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) { |
| 10408 | p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; |
| 10409 | shname = "F16_F32_L" ; |
| 10410 | } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10411 | p = ctx->device->pipeline_matmul_f16.f32acc->l; |
| 10412 | shname = "F16_L" ; |
| 10413 | } |
| 10414 | } |
| 10415 | } |
| 10416 | |
| 10417 | ggml_pipeline_request_descriptor_sets(ctx, p, num_it); |
| 10418 | if (split_k > 1) { |
| 10419 | ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); |
| 10420 | |
| 10421 | if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { |
| 10422 | // Resize buffer |
| 10423 | if (ctx->prealloc_split_k != nullptr) { |
| 10424 | ggml_vk_destroy_buffer(ctx->prealloc_split_k); |
| 10425 | } |
| 10426 | ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10427 | } |
| 10428 | } |
| 10429 | |
| 10430 | ggml_pipeline_allocate_descriptor_sets(ctx); |
| 10431 | |
| 10432 | vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10433 | vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10434 | vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10435 | |
| 10436 | X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); |
| 10437 | Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); |
| 10438 | float* d = (float *) malloc(sizeof(float) * d_ne); |
| 10439 | |
| 10440 | for (size_t i = 0; i < x_ne; i++) { |
| 10441 | if (std::is_same<float, X_TYPE>()) { |
| 10442 | x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; |
| 10443 | // x[i] = 1.0f; |
| 10444 | // x[i] = i + 1; |
| 10445 | // x[i] = (i % k == i / k) ? 1.0f : 0.0f; |
| 10446 | } else if (std::is_same<ggml_fp16_t, X_TYPE>()) { |
| 10447 | x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); |
| 10448 | // x[i] = ggml_fp32_to_fp16(1.0f); |
| 10449 | // x[i] = ggml_fp32_to_fp16(i + 1); |
| 10450 | // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); |
| 10451 | } else { |
| 10452 | GGML_ABORT("fatal error" ); |
| 10453 | } |
| 10454 | } |
| 10455 | for (size_t i = 0; i < y_ne; i++) { |
| 10456 | if (std::is_same<float, Y_TYPE>()) { |
| 10457 | y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; |
| 10458 | // y[i] = (i % k == i / k) ? 1.0f : 0.0f; |
| 10459 | // y[i] = i + 1; |
| 10460 | } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10461 | y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); |
| 10462 | // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); |
| 10463 | // y[i] = ggml_fp32_to_fp16(i + 1); |
| 10464 | } else { |
| 10465 | GGML_ABORT("fatal error" ); |
| 10466 | } |
| 10467 | } |
| 10468 | |
| 10469 | ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); |
| 10470 | ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); |
| 10471 | |
| 10472 | vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); |
| 10473 | ggml_vk_ctx_begin(ctx->device, subctx); |
| 10474 | for (size_t i = 0; i < num_it; i++) { |
| 10475 | ggml_vk_matmul( |
| 10476 | ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k), |
| 10477 | m, n, k, |
| 10478 | k, k, m, k*m, k*n, m*n, |
| 10479 | split_k, batch, batch, batch, 1, 1, n |
| 10480 | ); |
| 10481 | } |
| 10482 | ggml_vk_ctx_end(subctx); |
| 10483 | |
| 10484 | auto begin = std::chrono::high_resolution_clock::now(); |
| 10485 | ggml_vk_submit(subctx, ctx->fence); |
| 10486 | VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences" ); |
| 10487 | ctx->device->device.resetFences({ ctx->fence }); |
| 10488 | ggml_vk_queue_command_pools_cleanup(ctx->device); |
| 10489 | |
| 10490 | auto end = std::chrono::high_resolution_clock::now(); |
| 10491 | double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0; |
| 10492 | |
| 10493 | // copy dst to host |
| 10494 | ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne); |
| 10495 | |
| 10496 | float * d_chk = (float *) malloc(sizeof(float) * d_ne); |
| 10497 | |
| 10498 | ggml_init_params iparams = { |
| 10499 | /*.mem_size =*/ 1024*1024*1024, |
| 10500 | /*.mem_buffer =*/ NULL, |
| 10501 | /*.no_alloc =*/ true, |
| 10502 | }; |
| 10503 | |
| 10504 | ggml_context * ggml_ctx = ggml_init(iparams); |
| 10505 | |
| 10506 | ggml_type src0_type; |
| 10507 | ggml_type src1_type; |
| 10508 | |
| 10509 | if (std::is_same<float, X_TYPE>()) { |
| 10510 | src0_type = GGML_TYPE_F32; |
| 10511 | } else if (std::is_same<ggml_fp16_t, X_TYPE>()) { |
| 10512 | src0_type = GGML_TYPE_F16; |
| 10513 | } else { |
| 10514 | GGML_ABORT("fatal error" ); |
| 10515 | } |
| 10516 | if (std::is_same<float, Y_TYPE>()) { |
| 10517 | src1_type = GGML_TYPE_F32; |
| 10518 | } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) { |
| 10519 | src1_type = GGML_TYPE_F16; |
| 10520 | } else { |
| 10521 | GGML_ABORT("fatal error" ); |
| 10522 | } |
| 10523 | |
| 10524 | ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch); |
| 10525 | ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch); |
| 10526 | ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); |
| 10527 | |
| 10528 | src0_ggml->data = x; |
| 10529 | src1_ggml->data = y; |
| 10530 | tensor_ggml->data = d_chk; |
| 10531 | |
| 10532 | ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); |
| 10533 | ggml_build_forward_expand(cgraph, tensor_ggml); |
| 10534 | |
| 10535 | ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); |
| 10536 | |
| 10537 | ggml_free(ggml_ctx); |
| 10538 | |
| 10539 | double avg_err = 0.0; |
| 10540 | int first_err_n = -1; |
| 10541 | int first_err_m = -1; |
| 10542 | int first_err_b = -1; |
| 10543 | |
| 10544 | for (size_t i = 0; i < m*n*batch; i++) { |
| 10545 | double err = std::fabs(d[i] - d_chk[i]); |
| 10546 | avg_err += err; |
| 10547 | |
| 10548 | if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { |
| 10549 | first_err_b = i / (m * n); |
| 10550 | first_err_n = (i % (m * n)) / m; |
| 10551 | first_err_m = (i % (m * n)) % m; |
| 10552 | } |
| 10553 | } |
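|       |     // Worked example of the decomposition above (illustrative numbers only): the result is
|       |     // laid out as d[b*(m*n) + n_idx*m + m_idx], so with m=4, n=3, batch=2 a failing flat
|       |     // index i=17 maps to b = 17/12 = 1, n_idx = (17%12)/4 = 1, m_idx = (17%12)%4 = 1.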
| 10554 | |
| 10555 |     avg_err /= m * n * batch;
| 10556 | |
| 10557 | double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); |
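|       |     // Sanity check for the TFLOPS formula above: a matmul performs ~2*m*n*k FLOPs per batch
|       |     // element (one multiply and one add per term), e.g. m=n=k=4096 with batch=2 is ~2.75e11
|       |     // FLOPs per iteration; at 10 ms per iteration that is ~2.75e13 FLOP/s, i.e. ~27.5 TFLOPS.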
| 10558 | |
| 10559 | std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; |
| 10560 | |
| 10561 | if (avg_err > 0.1 || std::isnan(avg_err)) { |
| 10562 | std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; |
| 10563 | std::cerr << "Actual result: " << std::endl << std::endl; |
| 10564 | ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); |
| 10565 | std::cerr << "Expected result: " << std::endl << std::endl; |
| 10566 | ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); |
| 10567 | |
| 10568 | if (split_k > 1) { |
| 10569 | float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); |
| 10570 | ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); |
| 10571 | |
| 10572 |             for (size_t j = 0; j < split_k; j++) {
| 10573 |                 std::cerr << "d_buf" << j << ": " << std::endl << std::endl;
| 10574 |                 ggml_vk_print_matrix_area(split_k_buf + j * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
| 10575 |             }
| 10583 | |
| 10584 | free(split_k_buf); |
| 10585 | } |
| 10586 | } |
| 10587 | |
| 10588 | free(d_chk); |
| 10589 | |
| 10590 | ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); |
| 10591 | ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); |
| 10592 | |
| 10593 | ggml_vk_destroy_buffer(d_X); |
| 10594 | ggml_vk_destroy_buffer(d_Y); |
| 10595 | ggml_vk_destroy_buffer(d_D); |
| 10596 | |
| 10597 | free(x); |
| 10598 | free(y); |
| 10599 | free(d); |
| 10600 | } |
| 10601 | |
| 10602 | static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { |
| 10603 | if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { |
| 10604 | return; |
| 10605 | } |
| 10606 | i0 = std::max(i0, 5); |
| 10607 | i1 = std::max(i1, 5); |
| 10608 | i2 = std::max(i2, 0); |
| 10609 | i3 = std::max(i3, 0); |
| 10610 | fprintf(stderr, " " ); |
| 10611 | for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { |
| 10612 | fprintf(stderr, "%7d " , idx1); |
| 10613 | } |
| 10614 | fprintf(stderr, "\n" ); |
| 10615 | for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { |
| 10616 | fprintf(stderr, "%7d: " , idx0); |
| 10617 | for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { |
| 10618 | if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { |
| 10619 | float val; |
| 10620 | if (tensor->type == GGML_TYPE_F32) { |
| 10621 | val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); |
| 10622 | } else if (tensor->type == GGML_TYPE_F16) { |
| 10623 | val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); |
| 10624 | } else { |
| 10625 | GGML_ABORT("fatal error" ); |
| 10626 | } |
| 10627 | fprintf(stderr, "% 7.2f " , val); |
| 10628 | } else { |
| 10629 | fprintf(stderr, " " ); |
| 10630 | } |
| 10631 | } |
| 10632 | fprintf(stderr, "\n" ); |
| 10633 | } |
| 10634 | } |
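|       | // Usage sketch for the helper above: the clamps keep the 10x10 window start non-negative and
|       | // out-of-range entries print as blanks, e.g.
|       | //   ggml_vk_print_tensor_area(t, 0, 0, 0, 0);   // prints elements with idx0 and idx1 in 0..9
|       | //   ggml_vk_print_tensor_area(t, 42, 7, 0, 0);  // prints idx0 in 37..46 and idx1 in 2..11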
| 10635 | |
| 10636 | static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { |
| 10637 | ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); |
| 10638 | } |
| 10639 | |
| 10640 | static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { |
| 10641 | if (quant == GGML_TYPE_F32) { |
| 10642 | memcpy(to, from, sizeof(float) * ne); |
| 10643 | return; |
| 10644 | } |
| 10645 | |
| 10646 | const auto * tt = ggml_get_type_traits(quant); |
| 10647 | |
| 10648 | ggml_to_float_t dequant_fn = tt->to_float; |
| 10649 | |
| 10650 | dequant_fn(from, to, ne); |
| 10651 | } |
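|       | // Size arithmetic used by the tests below: a quantized buffer occupies
|       | // ne * ggml_type_size(quant) / ggml_blck_size(quant) bytes. For example, GGML_TYPE_Q8_0
|       | // packs 32 values into a 34-byte block (2-byte fp16 scale + 32 int8 quants), so
|       | // ne = 256 floats shrink from 1024 bytes of f32 to 8 blocks * 34 = 272 bytes.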
| 10652 | |
| 10653 | static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { |
| 10654 | VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")" ); |
| 10655 | const size_t x_sz = sizeof(float) * ne; |
| 10656 | const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; |
| 10657 | const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); |
| 10658 | float * x = (float *) malloc(x_sz); |
| 10659 | void * qx = malloc(qx_sz); |
| 10660 | vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10661 | vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10662 | float * x_ref = (float *) malloc(x_sz); |
| 10663 | ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); |
| 10664 | |
| 10665 | for (size_t i = 0; i < ne; i++) { |
| 10666 | x[i] = rand() / (float)RAND_MAX; |
| 10667 | } |
| 10668 | |
| 10669 | vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); |
| 10670 | |
| 10671 | ggml_vk_quantize_data(x, qx, ne, quant); |
| 10672 | ggml_vk_dequantize_data(qx, x_ref, ne, quant); |
| 10673 | |
| 10674 | ggml_pipeline_request_descriptor_sets(ctx, p, 1); |
| 10675 | |
| 10676 | ggml_pipeline_allocate_descriptor_sets(ctx); |
| 10677 | |
| 10678 | ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); |
| 10679 | |
| 10680 | vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); |
| 10681 | ggml_vk_ctx_begin(ctx->device, subctx); |
| 10682 | const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; |
| 10683 | ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1}); |
| 10684 | ggml_vk_ctx_end(subctx); |
| 10685 | |
| 10686 | auto begin = std::chrono::high_resolution_clock::now(); |
| 10687 | |
| 10688 | ggml_vk_submit(subctx, ctx->fence); |
| 10689 | VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences" ); |
| 10690 | ctx->device->device.resetFences({ ctx->fence }); |
| 10691 | ggml_vk_queue_command_pools_cleanup(ctx->device); |
| 10692 | |
| 10693 | auto end = std::chrono::high_resolution_clock::now(); |
| 10694 | |
| 10695 | double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0; |
| 10696 | ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); |
| 10697 | |
| 10698 | int first_err = -1; |
| 10699 | |
| 10700 | double avg_err = 0.0; |
| 10701 | for (size_t i = 0; i < ne; i++) { |
| 10702 | double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); |
| 10703 | avg_err += error; |
| 10704 | |
| 10705 | if (first_err < 0 && error > 0.05) { |
| 10706 | first_err = i; |
| 10707 | } |
| 10708 | } |
| 10709 | |
| 10710 | avg_err /= ne; |
| 10711 | |
| 10712 | std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; |
| 10713 | |
| 10714 | if (avg_err > 0.1) { |
| 10715 | std::cerr << "first_error = " << first_err << std::endl; |
| 10716 | std::cerr << "Actual result: " << std::endl << std::endl; |
| 10717 | for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { |
| 10718 | std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", " ; |
| 10719 | } |
| 10720 | std::cerr << std::endl << "Expected result: " << std::endl << std::endl; |
| 10721 | for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { |
| 10722 | std::cerr << x_ref[i] << ", " ; |
| 10723 | } |
| 10724 | std::cerr << std::endl; |
| 10725 | } |
| 10726 | |
| 10727 | ggml_vk_destroy_buffer(x_buf); |
| 10728 | ggml_vk_destroy_buffer(qx_buf); |
| 10729 | |
| 10730 | free(x); |
| 10731 | free(qx); |
| 10732 | free(x_ref); |
| 10733 | free(x_chk); |
| 10734 | } |
| 10735 | |
| 10736 | // This does not work without ggml q8_1 quantization support |
| 10737 | // |
| 10738 | // typedef uint16_t ggml_half; |
| 10739 | // typedef uint32_t ggml_half2; |
| 10740 | // |
| 10741 | // #define QK8_1 32 |
| 10742 | // typedef struct { |
| 10743 | // union { |
| 10744 | // struct { |
| 10745 | // ggml_half d; // delta |
| 10746 | // ggml_half s; // d * sum(qs[i]) |
| 10747 | // } GGML_COMMON_AGGR_S; |
| 10748 | // ggml_half2 ds; |
| 10749 | // } GGML_COMMON_AGGR_U; |
| 10750 | // int8_t qs[QK8_1]; // quants |
| 10751 | // } block_q8_1; |
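|       | // For reference, each block above packs QK8_1 = 32 values into 2 + 2 + 32 = 36 bytes
|       | // (fp16 delta d, fp16 d*sum(qs), 32 int8 quants), i.e. 9 bits per stored value.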
| 10752 | // |
| 10753 | // static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { |
| 10754 | // VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")"); |
| 10755 | // GGML_ASSERT(quant == GGML_TYPE_Q8_1); |
| 10756 | // |
| 10757 | // const size_t x_sz = sizeof(float) * ne; |
| 10758 | // const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); |
| 10759 | // float * x = (float *) malloc(x_sz); |
| 10760 | // block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); |
| 10761 | // block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); |
| 10762 | // vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10763 | // vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10764 | // |
| 10765 | // for (size_t i = 0; i < ne; i++) { |
| 10766 | // x[i] = rand() / (float)RAND_MAX; |
| 10767 | // } |
| 10768 | // |
| 10769 | // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); |
| 10770 | // |
| 10771 | // ggml_pipeline_request_descriptor_sets(ctx, p, 1); |
| 10772 | // |
| 10773 | // ggml_pipeline_allocate_descriptor_sets(ctx); |
| 10774 | // |
| 10775 | // ggml_vk_buffer_write(x_buf, 0, x, x_sz); |
| 10776 | // |
| 10777 | // vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); |
| 10778 | // ggml_vk_ctx_begin(ctx->device, subctx); |
| 10779 | // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne); |
| 10780 | // ggml_vk_ctx_end(subctx); |
| 10781 | // |
| 10782 | // auto begin = std::chrono::high_resolution_clock::now(); |
| 10783 | // |
| 10784 | // ggml_vk_submit(subctx, ctx->fence); |
| 10785 | // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); |
| 10786 | // ctx->device->device.resetFences({ ctx->fence }); |
| 10787 | // ggml_vk_queue_command_pools_cleanup(ctx->device); |
| 10788 | // |
| 10789 | // auto end = std::chrono::high_resolution_clock::now(); |
| 10790 | // |
| 10791 | // double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0; |
| 10792 | // ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz); |
| 10793 | // |
| 10794 | // ggml_vk_quantize_data(x, qx_res, ne, quant); |
| 10795 | // |
| 10796 | // int first_err = -1; |
| 10797 | // |
| 10798 | // for (size_t i = 0; i < ne / 32; i++) { |
| 10799 | // double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d)); |
| 10800 | // |
| 10801 | // if (first_err < 0 && error > 0.1) { |
| 10802 | // first_err = i; |
| 10803 | // } |
| 10804 | // |
| 10805 | // error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s)); |
| 10806 | // |
| 10807 | // if (first_err < 0 && error > 0.1) { |
| 10808 | // first_err = i; |
| 10809 | // } |
| 10810 | // |
| 10811 | // for (size_t j = 0; j < 32; j++) { |
| 10812 | // uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]); |
| 10813 | // |
| 10814 | // if (first_err < 0 && error > 1) { |
| 10815 | // first_err = i; |
| 10816 | // } |
| 10817 | // } |
| 10818 | // } |
| 10819 | // |
| 10820 | // std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl; |
| 10821 | // |
| 10822 | // if (first_err != -1) { |
| 10823 | // std::cerr << "first_error = " << first_err << std::endl; |
| 10824 | // std::cerr << "Actual result: " << std::endl << std::endl; |
| 10825 | // std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; |
| 10826 | // for (size_t j = 0; j < 32; j++) { |
| 10827 | // std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " "; |
| 10828 | // } |
| 10829 | // std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl; |
| 10830 | // std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; |
| 10831 | // for (size_t j = 0; j < 32; j++) { |
| 10832 | // std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " "; |
| 10833 | // } |
| 10834 | // std::cerr << std::endl; |
| 10835 | // } |
| 10836 | // |
| 10837 | // ggml_vk_destroy_buffer(x_buf); |
| 10838 | // ggml_vk_destroy_buffer(qx_buf); |
| 10839 | // |
| 10840 | // free(x); |
| 10841 | // free(qx); |
| 10842 | // free(qx_res); |
| 10843 | // } |
| 10844 | |
| 10845 | static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) { |
| 10846 | VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" ); |
| 10847 | const size_t x_ne = m * k * batch; |
| 10848 | const size_t y_ne = k * n * batch; |
| 10849 | const size_t d_ne = m * n * batch; |
| 10850 | |
| 10851 | vk_matmul_pipeline2 * pipelines; |
| 10852 | |
| 10853 | if (mmq) { |
| 10854 | pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1; |
| 10855 | } else { |
| 10856 | pipelines = ctx->device->pipeline_dequant_mul_mat_mat; |
| 10857 | } |
| 10858 | |
| 10859 | const bool fp16acc = ctx->device->fp16; |
| 10860 | |
| 10861 | vk_pipeline p; |
| 10862 | std::string shname; |
| 10863 | if (shader_size == 0) { |
| 10864 | p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s; |
| 10865 | shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S" ; |
| 10866 | } else if (shader_size == 1) { |
| 10867 | p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m; |
| 10868 | shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M" ; |
| 10869 | } else if (shader_size == 2) { |
| 10870 | p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l; |
| 10871 | shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L" ; |
| 10872 | } else { |
| 10873 | GGML_ASSERT(0); |
| 10874 | } |
| 10875 | |
| 10876 | const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align); |
| 10877 | |
| 10878 | if (mmq || k != kpad) { |
| 10879 | if (shader_size == 0) { |
| 10880 | p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s; |
| 10881 | shname = std::string(ggml_type_name(quant)) + "_S" ; |
| 10882 | } else if (shader_size == 1) { |
| 10883 | p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m; |
| 10884 | shname = std::string(ggml_type_name(quant)) + "_M" ; |
| 10885 | } else if (shader_size == 2) { |
| 10886 | p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l; |
| 10887 | shname = std::string(ggml_type_name(quant)) + "_L" ; |
| 10888 | } else { |
| 10889 | GGML_ASSERT(0); |
| 10890 | } |
| 10891 | } |
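|       |     // Example of the fallback above (illustrative numbers): with k = 511 and p->align = 32,
|       |     // kpad = ggml_vk_align_size(511, 32) = 512 != k, so the non-aligned variant replaces the
|       |     // "_ALIGNED_*" pipeline picked first; with k = 4096 the aligned pipeline is kept. For MMQ,
|       |     // kpad is forced to 0, so the non-aligned variant is always used.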
| 10892 | |
| 10893 | if (p == nullptr) { |
| 10894 | std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl; |
| 10895 | return; |
| 10896 | } |
| 10897 | |
| 10898 | const size_t x_sz = sizeof(float) * x_ne; |
| 10899 | const size_t y_sz = sizeof(float) * y_ne; |
| 10900 | const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); |
| 10901 | const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz; |
| 10902 | const size_t d_sz = sizeof(float) * d_ne; |
| 10903 | float * x = (float *) malloc(x_sz); |
| 10904 | float * y = (float *) malloc(y_sz); |
| 10905 | void * qx = malloc(qx_sz); |
| 10906 | vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10907 | vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10908 | vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10909 | vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10910 | float * d = (float *) malloc(d_sz); |
| 10911 | float * d_chk = (float *) malloc(d_sz); |
| 10912 | |
| 10913 | for (size_t i = 0; i < x_ne; i++) { |
| 10914 | x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; |
| 10915 | // x[i] = (i % k == i / k) ? 1.0f : 0.0f; |
| 10916 | // x[i] = i % k; |
| 10917 | } |
| 10918 | |
| 10919 | ggml_vk_quantize_data(x, qx, x_ne, quant); |
| 10920 | |
| 10921 | for (size_t i = 0; i < y_ne; i++) { |
| 10922 | y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; |
| 10923 | // y[i] = (i % k == i / k) ? 1.0f : 0.0f; |
| 10924 | // y[i] = i % k; |
| 10925 | } |
| 10926 | |
| 10927 | ggml_pipeline_request_descriptor_sets(ctx, p, num_it); |
| 10928 | if (split_k > 1) { |
| 10929 | ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); |
| 10930 | |
| 10931 | if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { |
| 10932 | // Resize buffer |
| 10933 | if (ctx->prealloc_split_k != nullptr) { |
| 10934 | ggml_vk_destroy_buffer(ctx->prealloc_split_k); |
| 10935 | } |
| 10936 | ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); |
| 10937 | } |
| 10938 | } |
| 10939 | if (mmq) { |
| 10940 | ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); |
| 10941 | } |
| 10942 | |
| 10943 | ggml_pipeline_allocate_descriptor_sets(ctx); |
| 10944 | |
| 10945 | ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); |
| 10946 | ggml_vk_buffer_write(y_buf, 0, y, y_sz); |
| 10947 | |
| 10948 | vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); |
| 10949 | ggml_vk_ctx_begin(ctx->device, subctx); |
| 10950 | if (mmq) { |
| 10951 | for (size_t i = 0; i < num_it; i++) { |
| 10952 | ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne); |
| 10953 | ggml_vk_matmul( |
| 10954 | ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, |
| 10955 | m, n, k, |
| 10956 | k, k, m, k*m, k*n, m*n, |
| 10957 | split_k, batch, batch, batch, 1, 1, n |
| 10958 | ); |
| 10959 | } |
| 10960 | } else { |
| 10961 | for (size_t i = 0; i < num_it; i++) { |
| 10962 | ggml_vk_matmul( |
| 10963 | ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, |
| 10964 | m, n, k, |
| 10965 | k, k, m, k*m, k*n, m*n, |
| 10966 | split_k, batch, batch, batch, 1, 1, n |
| 10967 | ); |
| 10968 | } |
| 10969 | } |
| 10970 | ggml_vk_ctx_end(subctx); |
| 10971 | |
| 10972 | auto begin = std::chrono::high_resolution_clock::now(); |
| 10973 | |
| 10974 | ggml_vk_submit(subctx, ctx->fence); |
| 10975 | VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences" ); |
| 10976 | ctx->device->device.resetFences({ ctx->fence }); |
| 10977 | ggml_vk_queue_command_pools_cleanup(ctx->device); |
| 10978 | |
| 10979 | auto end = std::chrono::high_resolution_clock::now(); |
| 10980 | |
| 10981 | double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0; |
| 10982 | ggml_vk_buffer_read(d_buf, 0, d, d_sz); |
| 10983 | |
| 10984 | ggml_init_params iparams = { |
| 10985 | /*.mem_size =*/ 1024*1024*1024, |
| 10986 | /*.mem_buffer =*/ NULL, |
| 10987 | /*.no_alloc =*/ true, |
| 10988 | }; |
| 10989 | |
| 10990 | ggml_context * ggml_ctx = ggml_init(iparams); |
| 10991 | |
| 10992 | ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); |
| 10993 | ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); |
| 10994 | ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); |
| 10995 | |
| 10996 | src0_ggml->data = qx; |
| 10997 | src1_ggml->data = y; |
| 10998 | tensor_ggml->data = d_chk; |
| 10999 | |
| 11000 | ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); |
| 11001 | ggml_build_forward_expand(cgraph, tensor_ggml); |
| 11002 | |
| 11003 | ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); |
| 11004 | |
| 11005 | ggml_free(ggml_ctx); |
| 11006 | |
| 11007 | double avg_err = 0.0; |
| 11008 | int first_err_n = -1; |
| 11009 | int first_err_m = -1; |
| 11010 | int first_err_b = -1; |
| 11011 | |
| 11012 | for (size_t i = 0; i < m*n*batch; i++) { |
| 11013 | double err = std::fabs(d[i] - d_chk[i]); |
| 11014 | avg_err += err; |
| 11015 | |
| 11016 | if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { |
| 11017 | first_err_b = i / (m * n); |
| 11018 | first_err_n = (i % (m * n)) / m; |
| 11019 | first_err_m = (i % (m * n)) % m; |
| 11020 | } |
| 11021 | } |
| 11022 | |
| 11023 |     avg_err /= m * n * batch;
| 11024 | |
| 11025 | double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); |
| 11026 | |
| 11027 | std::cerr << "TEST dequant matmul " << shname; |
| 11028 | if (mmq) { |
| 11029 | std::cerr << " mmq" ; |
| 11030 | } |
| 11031 | std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; |
| 11032 | |
| 11033 | if (avg_err > 0.01 || std::isnan(avg_err)) { |
| 11034 | std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; |
| 11035 | std::cerr << "Actual result: " << std::endl << std::endl; |
| 11036 | ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); |
| 11037 | std::cerr << std::endl; |
| 11038 | std::cerr << "Expected result: " << std::endl << std::endl; |
| 11039 | ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); |
| 11040 | |
| 11041 | std::cerr << "src0: " << std::endl << std::endl; |
| 11042 | ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b); |
| 11043 | std::cerr << std::endl; |
| 11044 | std::cerr << "src1: " << std::endl << std::endl; |
| 11045 | ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b); |
| 11046 | |
| 11047 | if (split_k > 1) { |
| 11048 | float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); |
| 11049 | ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); |
| 11050 | |
| 11051 |             for (size_t j = 0; j < split_k; j++) {
| 11052 |                 std::cerr << "d_buf" << j << ": " << std::endl << std::endl;
| 11053 |                 ggml_vk_print_matrix_area(split_k_buf + j * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
| 11054 |             }
| 11062 | |
| 11063 | free(split_k_buf); |
| 11064 | } |
| 11065 | } |
| 11066 | |
| 11067 | ggml_vk_destroy_buffer(qx_buf); |
| 11068 | ggml_vk_destroy_buffer(y_buf); |
| 11069 | ggml_vk_destroy_buffer(qy_buf); |
| 11070 | ggml_vk_destroy_buffer(d_buf); |
| 11071 | |
| 11072 | free(x); |
| 11073 | free(qx); |
| 11074 | free(y); |
| 11075 | free(d); |
| 11076 | free(d_chk); |
| 11077 | } |
| 11078 | #endif |
| 11079 | |
| 11080 | static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx) { |
| 11081 | #if defined(GGML_VULKAN_RUN_TESTS) |
| 11082 | const std::vector<size_t> vals { |
| 11083 | 512, 512, 128, |
| 11084 | 128, 512, 512, |
| 11085 | 4096, 512, 4096, |
| 11086 | 11008, 512, 4096, |
| 11087 | 4096, 512, 11008, |
| 11088 | 32000, 512, 4096, |
| 11089 | 8, 8, 8, |
| 11090 | 100, 46, 576, |
| 11091 | 623, 111, 128, |
| 11092 | 100, 46, 558, |
| 11093 | 512, 1, 256, |
| 11094 | 128, 110, 622, |
| 11095 | 511, 511, 127, |
| 11096 | 511, 511, 7, |
| 11097 | 511, 511, 17, |
| 11098 | 49, 49, 128, |
| 11099 | 128, 49, 49, |
| 11100 | 4096, 49, 4096, |
| 11101 | }; |
| 11102 | const size_t num_it = 100; |
| 11103 | |
| 11104 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); |
| 11105 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); |
| 11106 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0); |
| 11107 | |
| 11108 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true); |
| 11109 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true); |
| 11110 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true); |
| 11111 | |
| 11112 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0); |
| 11113 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0); |
| 11114 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0); |
| 11115 | |
| 11116 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true); |
| 11117 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true); |
| 11118 | ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true); |
| 11119 | |
| 11120 | abort(); |
| 11121 | |
| 11122 | for (size_t i = 0; i < vals.size(); i += 3) { |
| 11123 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); |
| 11124 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); |
| 11125 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); |
| 11126 | std::cerr << '\n'; |
| 11127 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); |
| 11128 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); |
| 11129 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); |
| 11130 | std::cerr << '\n'; |
| 11131 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); |
| 11132 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); |
| 11133 | ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); |
| 11134 | std::cerr << '\n' << std::endl; |
| 11135 | |
| 11136 | if (vals[i + 2] % 32 == 0) { |
| 11137 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); |
| 11138 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); |
| 11139 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); |
| 11140 | std::cerr << '\n'; |
| 11141 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); |
| 11142 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); |
| 11143 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); |
| 11144 | std::cerr << '\n'; |
| 11145 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); |
| 11146 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); |
| 11147 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); |
| 11148 | std::cerr << '\n' << std::endl; |
| 11149 | } |
| 11150 | |
| 11151 | if (vals[i + 2] % 256 == 0) { |
| 11152 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); |
| 11153 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); |
| 11154 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); |
| 11155 | std::cerr << '\n'; |
| 11156 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); |
| 11157 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); |
| 11158 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); |
| 11159 | std::cerr << '\n'; |
| 11160 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); |
| 11161 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); |
| 11162 | ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); |
| 11163 | std::cerr << '\n' << std::endl; |
| 11164 | } |
| 11165 | } |
| 11166 | |
| 11167 | GGML_ABORT("fatal error" ); |
| 11168 | #endif |
| 11169 | |
| 11170 | if (subctx) { |
| 11171 | // Submit and wait for any pending work before reallocating the buffers |
| 11172 |         ggml_vk_ctx_end(subctx);
| 11173 |         ggml_vk_submit(subctx, ctx->fence);
| 11174 |         ggml_vk_wait_for_fence(ctx);
| 11175 |         ggml_vk_ctx_begin(ctx->device, subctx);
| 11176 | } |
| 11177 | |
| 11178 | if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { |
| 11179 | VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")" ); |
| 11180 | // Resize buffer |
| 11181 | if (ctx->prealloc_x != nullptr) { |
| 11182 |             ggml_vk_destroy_buffer(ctx->prealloc_x);
| 11183 |         }
| 11184 |         ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
| 11185 | } |
| 11186 | if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) { |
| 11187 | VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")" ); |
| 11188 | // Resize buffer |
| 11189 | if (ctx->prealloc_y != nullptr) { |
| 11190 |             ggml_vk_destroy_buffer(ctx->prealloc_y);
| 11191 |         }
| 11192 |         ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
| 11193 | } |
| 11194 | if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { |
| 11195 | VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")" ); |
| 11196 | // Resize buffer |
| 11197 | if (ctx->prealloc_split_k != nullptr) { |
| 11198 |             ggml_vk_destroy_buffer(ctx->prealloc_split_k);
| 11199 |         }
| 11200 |         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
| 11201 | } |
| 11202 | if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) { |
| 11203 |         VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_size_add_rms_partials << ")");
| 11204 | // Resize buffer |
| 11205 | if (ctx->prealloc_add_rms_partials != nullptr) { |
| 11206 |             ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
| 11207 |         }
| 11208 |         ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
| 11209 | } |
| 11210 | } |
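|       | 
|       | // Illustrative sketch only (not called anywhere): the grow-only reallocation pattern used
|       | // repeatedly above, factored into a hypothetical helper. The helper name is an assumption
|       | // made for this example; the calls it wraps are the same ones ggml_vk_preallocate_buffers uses.
|       | [[maybe_unused]] static void ggml_vk_sketch_ensure_device_buffer(ggml_backend_vk_context * ctx, vk_buffer & buf, size_t required_size) {
|       |     if (buf == nullptr || (required_size > 0 && buf->size < required_size)) {
|       |         if (buf != nullptr) {
|       |             ggml_vk_destroy_buffer(buf);                                  // release the too-small allocation
|       |         }
|       |         buf = ggml_vk_create_buffer_device(ctx->device, required_size);  // reallocate at the requested size
|       |     }
|       | }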
| 11211 | |
| 11212 | static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); |
| 11213 | |
| 11214 | // Returns true if node has enqueued work into the queue, false otherwise |
| 11215 | // If submit is true, all operations queued so far are submitted to Vulkan, overlapping command list creation with GPU execution.
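|       | // A hypothetical driver loop (illustration only, not the actual caller) might look like:
|       | //   for (int i = 0; i < cgraph->n_nodes; i++) {
|       | //       bool last = (i == cgraph->n_nodes - 1);
|       | //       // batch_begin marks the first node of the command batch currently being recorded;
|       | //       // submit=true flushes everything recorded so far to the GPU.
|       | //       ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[batch_begin], batch_begin, last, last, /*submit=*/last);
|       | //   }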
| 11216 | static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){ |
| 11217 | ggml_tensor * node = cgraph->nodes[node_idx]; |
| 11218 |     if (ggml_is_empty(node) || !node->buffer) {
| 11219 | return false; |
| 11220 | } |
| 11221 | |
| 11222 | VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")" ); |
| 11223 | ctx->semaphore_idx = 0; |
| 11224 | |
| 11225 | ggml_tensor * src0 = node->src[0]; |
| 11226 | ggml_tensor * src1 = node->src[1]; |
| 11227 | ggml_tensor * src2 = node->src[2]; |
| 11228 | ggml_tensor * src3 = node->src[3]; |
| 11229 | |
| 11230 | switch (node->op) { |
| 11231 | // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor |
| 11232 | case GGML_OP_RESHAPE: |
| 11233 | case GGML_OP_VIEW: |
| 11234 | case GGML_OP_PERMUTE: |
| 11235 | case GGML_OP_TRANSPOSE: |
| 11236 | case GGML_OP_NONE: |
| 11237 | return false; |
| 11238 | case GGML_OP_UNARY: |
| 11239 |         switch (ggml_get_unary_op(node)) {
| 11240 | case GGML_UNARY_OP_EXP: |
| 11241 | case GGML_UNARY_OP_SILU: |
| 11242 | case GGML_UNARY_OP_GELU: |
| 11243 | case GGML_UNARY_OP_GELU_ERF: |
| 11244 | case GGML_UNARY_OP_GELU_QUICK: |
| 11245 | case GGML_UNARY_OP_RELU: |
| 11246 | case GGML_UNARY_OP_TANH: |
| 11247 | case GGML_UNARY_OP_SIGMOID: |
| 11248 | case GGML_UNARY_OP_HARDSIGMOID: |
| 11249 | case GGML_UNARY_OP_HARDSWISH: |
| 11250 | break; |
| 11251 | default: |
| 11252 | return false; |
| 11253 | } |
| 11254 | break; |
| 11255 | case GGML_OP_GLU: |
| 11256 |         switch (ggml_get_glu_op(node)) {
| 11257 | case GGML_GLU_OP_GEGLU: |
| 11258 | case GGML_GLU_OP_REGLU: |
| 11259 | case GGML_GLU_OP_SWIGLU: |
| 11260 | case GGML_GLU_OP_SWIGLU_OAI: |
| 11261 | case GGML_GLU_OP_GEGLU_ERF: |
| 11262 | case GGML_GLU_OP_GEGLU_QUICK: |
| 11263 | break; |
| 11264 | default: |
| 11265 | return false; |
| 11266 | } |
| 11267 | break; |
| 11268 | case GGML_OP_ADD: |
| 11269 | { |
| 11270 | int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; |
| 11271 | if (next_node_idx < cgraph->n_nodes && |
| 11272 | cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && |
| 11273 | cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && |
| 11274 |                 ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
| 11275 |                 ctx->device->add_rms_fusion) {
| 11276 |                 uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
| 11277 | ctx->do_add_rms_partials_offset_calculation = true; |
| 11278 | if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) { |
| 11279 | ctx->do_add_rms_partials = true; |
| 11280 | } |
| 11281 | } |
| 11282 | } break; |
| 11283 | case GGML_OP_REPEAT: |
| 11284 | case GGML_OP_REPEAT_BACK: |
| 11285 | case GGML_OP_GET_ROWS: |
| 11286 | case GGML_OP_ADD_ID: |
| 11287 | case GGML_OP_ACC: |
| 11288 | case GGML_OP_SUB: |
| 11289 | case GGML_OP_MUL: |
| 11290 | case GGML_OP_DIV: |
| 11291 | case GGML_OP_CONCAT: |
| 11292 | case GGML_OP_UPSCALE: |
| 11293 | case GGML_OP_SCALE: |
| 11294 | case GGML_OP_SQR: |
| 11295 | case GGML_OP_SQRT: |
| 11296 | case GGML_OP_SIN: |
| 11297 | case GGML_OP_COS: |
| 11298 | case GGML_OP_CLAMP: |
| 11299 | case GGML_OP_PAD: |
| 11300 | case GGML_OP_ROLL: |
| 11301 | case GGML_OP_CPY: |
| 11302 | case GGML_OP_SET_ROWS: |
| 11303 | case GGML_OP_CONT: |
| 11304 | case GGML_OP_DUP: |
| 11305 | case GGML_OP_SILU_BACK: |
| 11306 | case GGML_OP_NORM: |
| 11307 | case GGML_OP_GROUP_NORM: |
| 11308 | case GGML_OP_RMS_NORM: |
| 11309 | case GGML_OP_RMS_NORM_BACK: |
| 11310 | case GGML_OP_L2_NORM: |
| 11311 | case GGML_OP_DIAG_MASK_INF: |
| 11312 | case GGML_OP_SOFT_MAX: |
| 11313 | case GGML_OP_SOFT_MAX_BACK: |
| 11314 | case GGML_OP_ROPE: |
| 11315 | case GGML_OP_ROPE_BACK: |
| 11316 | case GGML_OP_MUL_MAT: |
| 11317 | case GGML_OP_MUL_MAT_ID: |
| 11318 | case GGML_OP_ARGSORT: |
| 11319 | case GGML_OP_SUM: |
| 11320 | case GGML_OP_SUM_ROWS: |
| 11321 | case GGML_OP_MEAN: |
| 11322 | case GGML_OP_ARGMAX: |
| 11323 | case GGML_OP_COUNT_EQUAL: |
| 11324 | case GGML_OP_IM2COL: |
| 11325 | case GGML_OP_IM2COL_3D: |
| 11326 | case GGML_OP_TIMESTEP_EMBEDDING: |
| 11327 | case GGML_OP_CONV_TRANSPOSE_1D: |
| 11328 | case GGML_OP_POOL_2D: |
| 11329 | case GGML_OP_CONV_2D: |
| 11330 | case GGML_OP_CONV_TRANSPOSE_2D: |
| 11331 | case GGML_OP_CONV_2D_DW: |
| 11332 | case GGML_OP_RWKV_WKV6: |
| 11333 | case GGML_OP_RWKV_WKV7: |
| 11334 | case GGML_OP_SSM_SCAN: |
| 11335 | case GGML_OP_SSM_CONV: |
| 11336 | case GGML_OP_LEAKY_RELU: |
| 11337 | case GGML_OP_FLASH_ATTN_EXT: |
| 11338 | case GGML_OP_OPT_STEP_ADAMW: |
| 11339 | case GGML_OP_OPT_STEP_SGD: |
| 11340 | break; |
| 11341 | default: |
| 11342 |         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
| 11343 | GGML_ABORT("fatal error" ); |
| 11344 | } |
| 11345 | |
| 11346 | vk_context compute_ctx; |
| 11347 | |
| 11348 | if (ctx->compute_ctx.expired()) { |
| 11349 |         compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
| 11350 |         ctx->compute_ctx = compute_ctx;
| 11351 |         ggml_vk_ctx_begin(ctx->device, compute_ctx);
| 11352 | } else { |
| 11353 | compute_ctx = ctx->compute_ctx.lock(); |
| 11354 | } |
| 11355 | |
| 11356 | { |
| 11357 |         // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
| 11358 | // to synchronize them. This handles most "normal" synchronization when computing the graph, and when |
| 11359 | // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers |
| 11360 | // outside of this logic. When a node uses one of the prealloc buffers for something like |
| 11361 | // dequantization or split_k, additional synchronization is needed between those passes. |
| 11362 | bool need_sync = false; |
| 11363 | |
| 11364 | // Check whether "node" requires synchronization. The node requires synchronization if it |
| 11365 | // overlaps in memory with another unsynchronized node and at least one of them is a write. |
| 11366 | // Destination nodes are checked against both the written/read lists. Source nodes are only |
| 11367 | // checked against the written list. Two nodes overlap in memory if they come from the same |
| 11368 | // buffer and the tensor or view ranges overlap. |
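|       |         // For example, byte ranges A = [256, 256+128) and B = [320, 320+64) in the same
|       |         // vk_buffer overlap (256 <= 320 and 320 < 384), so a barrier is needed between the
|       |         // two nodes; ranges [0, 128) and [128, 192) merely touch and do not require one.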
| 11369 | auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool { |
| 11370 | if (unsynced_nodes.size() == 0) { |
| 11371 | return false; |
| 11372 | } |
| 11373 |             auto n_base = vk_tensor_offset(node) + node->view_offs;
| 11374 |             auto n_size = ggml_nbytes(node);
| 11375 | ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context; |
| 11376 | vk_buffer a_buf = a_buf_ctx->dev_buffer; |
| 11377 | for (auto &other : unsynced_nodes) { |
| 11378 | ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context; |
| 11379 | vk_buffer o_buf = o_buf_ctx->dev_buffer; |
| 11380 | if (a_buf == o_buf) { |
| 11381 |                     auto o_base = vk_tensor_offset(other) + other->view_offs;
| 11382 |                     auto o_size = ggml_nbytes(other);
| 11383 | |
| 11384 | if ((o_base <= n_base && n_base < o_base + o_size) || |
| 11385 | (n_base <= o_base && o_base < n_base + n_size)) { |
| 11386 | return true; |
| 11387 | } |
| 11388 | } |
| 11389 | } |
| 11390 | return false; |
| 11391 | }; |
| 11392 | |
| 11393 | // For all fused ops, check if the destination node or any of the source |
| 11394 | // nodes require synchronization. |
| 11395 | for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { |
| 11396 | const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; |
| 11397 | // If the node actually writes to memory, then check if it needs to sync |
| 11398 | if (ctx->fused_ops_write_mask & (1 << i)) { |
| 11399 | if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { |
| 11400 | need_sync = true; |
| 11401 | break; |
| 11402 | } |
| 11403 | } |
| 11404 | for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { |
| 11405 | if (!cur_node->src[j]) { |
| 11406 | continue; |
| 11407 | } |
| 11408 | if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) { |
| 11409 | need_sync = true; |
| 11410 | break; |
| 11411 | } |
| 11412 | } |
| 11413 | } |
| 11414 | |
| 11415 | #define ENABLE_SYNC_LOGGING 0 |
| 11416 | |
| 11417 | if (need_sync) { |
| 11418 | #if ENABLE_SYNC_LOGGING |
| 11419 | std::cerr << "sync" << std::endl; |
| 11420 | #endif |
| 11421 | ctx->unsynced_nodes_written.clear(); |
| 11422 | ctx->unsynced_nodes_read.clear(); |
| 11423 |             ggml_vk_sync_buffers(ctx, compute_ctx);
| 11424 | } |
| 11425 | // Add all fused nodes to the unsynchronized lists. |
| 11426 | for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { |
| 11427 | const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; |
| 11428 | // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list. |
| 11429 | if (ctx->fused_ops_write_mask & (1 << i)) { |
| 11430 |                 ctx->unsynced_nodes_written.push_back(cur_node);
| 11431 | } |
| 11432 | for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { |
| 11433 | if (!cur_node->src[j]) { |
| 11434 | continue; |
| 11435 | } |
| 11436 |                 ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
| 11437 | } |
| 11438 | } |
| 11439 | } |
| 11440 | #if ENABLE_SYNC_LOGGING |
| 11441 | for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { |
| 11442 | auto *n = cgraph->nodes[node_idx + i]; |
| 11443 | std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; |
| 11444 | if (n->op == GGML_OP_GLU) { |
| 11445 | std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single" ) << " " ; |
| 11446 | } |
| 11447 | if (n->op == GGML_OP_ROPE) { |
| 11448 | const int mode = ((const int32_t *) n->op_params)[2]; |
| 11449 | std::cerr << " rope mode: " << mode; |
| 11450 | } |
| 11451 | std::cerr << std::endl; |
| 11452 | } |
| 11453 | #endif |
| 11454 | |
| 11455 | switch (node->op) { |
| 11456 | case GGML_OP_REPEAT: |
| 11457 |         ggml_vk_repeat(ctx, compute_ctx, src0, node);
| 11458 | |
| 11459 | break; |
| 11460 | case GGML_OP_REPEAT_BACK: |
| 11461 |         ggml_vk_repeat_back(ctx, compute_ctx, src0, node);
| 11462 | |
| 11463 | break; |
| 11464 | case GGML_OP_ACC: |
| 11465 |         ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
| 11466 | |
| 11467 | break; |
| 11468 | case GGML_OP_GET_ROWS: |
| 11469 |         ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
| 11470 | |
| 11471 | break; |
| 11472 | case GGML_OP_ADD: |
| 11473 | if (ctx->num_additional_fused_ops) { |
| 11474 |             ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx);
| 11475 |         } else {
| 11476 |             ggml_vk_add(ctx, compute_ctx, src0, src1, node);
| 11477 | } |
| 11478 | break; |
| 11479 | case GGML_OP_SUB: |
| 11480 |         ggml_vk_sub(ctx, compute_ctx, src0, src1, node);
| 11481 | |
| 11482 | break; |
| 11483 | case GGML_OP_MUL: |
| 11484 |         ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
| 11485 | |
| 11486 | break; |
| 11487 | case GGML_OP_DIV: |
| 11488 |         ggml_vk_div(ctx, compute_ctx, src0, src1, node);
| 11489 | |
| 11490 | break; |
| 11491 | case GGML_OP_ADD_ID: |
| 11492 |         ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node);
| 11493 | |
| 11494 | break; |
| 11495 | case GGML_OP_CONCAT: |
| 11496 |         ggml_vk_concat(ctx, compute_ctx, src0, src1, node);
| 11497 | |
| 11498 | break; |
| 11499 | case GGML_OP_UPSCALE: |
| 11500 |         ggml_vk_upscale(ctx, compute_ctx, src0, node);
| 11501 | |
| 11502 | break; |
| 11503 | case GGML_OP_SCALE: |
| 11504 |         ggml_vk_scale(ctx, compute_ctx, src0, node);
| 11505 | |
| 11506 | break; |
| 11507 | case GGML_OP_SQR: |
| 11508 |         ggml_vk_sqr(ctx, compute_ctx, src0, node);
| 11509 | |
| 11510 | break; |
| 11511 | case GGML_OP_SQRT: |
| 11512 |         ggml_vk_sqrt(ctx, compute_ctx, src0, node);
| 11513 | |
| 11514 | break; |
| 11515 | case GGML_OP_SIN: |
| 11516 |         ggml_vk_sin(ctx, compute_ctx, src0, node);
| 11517 | |
| 11518 | break; |
| 11519 | case GGML_OP_COS: |
| 11520 |         ggml_vk_cos(ctx, compute_ctx, src0, node);
| 11521 | |
| 11522 | break; |
| 11523 | case GGML_OP_CLAMP: |
| 11524 |         ggml_vk_clamp(ctx, compute_ctx, src0, node);
| 11525 | |
| 11526 | break; |
| 11527 | case GGML_OP_PAD: |
| 11528 |         ggml_vk_pad(ctx, compute_ctx, src0, node);
| 11529 | |
| 11530 | break; |
| 11531 | case GGML_OP_ROLL: |
| 11532 |         ggml_vk_roll(ctx, compute_ctx, src0, node);
| 11533 | |
| 11534 | break; |
| 11535 | case GGML_OP_CPY: |
| 11536 | case GGML_OP_CONT: |
| 11537 | case GGML_OP_DUP: |
| 11538 |         ggml_vk_cpy(ctx, compute_ctx, src0, node);
| 11539 | |
| 11540 | break; |
| 11541 | case GGML_OP_SET_ROWS: |
| 11542 |         ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);
| 11543 | |
| 11544 | break; |
| 11545 | case GGML_OP_SILU_BACK: |
| 11546 |         ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);
| 11547 | |
| 11548 | break; |
| 11549 | case GGML_OP_NORM: |
| 11550 |         ggml_vk_norm(ctx, compute_ctx, src0, node);
| 11551 | |
| 11552 | break; |
| 11553 | case GGML_OP_GROUP_NORM: |
| 11554 |         ggml_vk_group_norm(ctx, compute_ctx, src0, node);
| 11555 | |
| 11556 | break; |
| 11557 | case GGML_OP_RMS_NORM: |
| 11558 |         ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
| 11559 | break; |
| 11560 | case GGML_OP_RMS_NORM_BACK: |
| 11561 |         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
| 11562 | |
| 11563 | break; |
| 11564 | case GGML_OP_L2_NORM: |
| 11565 |         ggml_vk_l2_norm(ctx, compute_ctx, src0, node);
| 11566 | |
| 11567 | break; |
| 11568 | case GGML_OP_UNARY: |
| 11569 |         switch (ggml_get_unary_op(node)) {
| 11570 | case GGML_UNARY_OP_EXP: |
| 11571 | case GGML_UNARY_OP_SILU: |
| 11572 | case GGML_UNARY_OP_GELU: |
| 11573 | case GGML_UNARY_OP_GELU_ERF: |
| 11574 | case GGML_UNARY_OP_GELU_QUICK: |
| 11575 | case GGML_UNARY_OP_RELU: |
| 11576 | case GGML_UNARY_OP_TANH: |
| 11577 | case GGML_UNARY_OP_SIGMOID: |
| 11578 | case GGML_UNARY_OP_HARDSIGMOID: |
| 11579 | case GGML_UNARY_OP_HARDSWISH: |
| 11580 |             ggml_vk_unary(ctx, compute_ctx, src0, node);
| 11581 | break; |
| 11582 | default: |
| 11583 | return false; |
| 11584 | } |
| 11585 | break; |
| 11586 | case GGML_OP_GLU: |
| 11587 |         switch (ggml_get_glu_op(node)) {
| 11588 | case GGML_GLU_OP_GEGLU: |
| 11589 | case GGML_GLU_OP_REGLU: |
| 11590 | case GGML_GLU_OP_SWIGLU: |
| 11591 | case GGML_GLU_OP_SWIGLU_OAI: |
| 11592 | case GGML_GLU_OP_GEGLU_ERF: |
| 11593 | case GGML_GLU_OP_GEGLU_QUICK: |
| 11594 |             ggml_vk_glu(ctx, compute_ctx, src0, src1, node);
| 11595 | break; |
| 11596 | default: |
| 11597 | return false; |
| 11598 | } |
| 11599 | break; |
| 11600 | case GGML_OP_DIAG_MASK_INF: |
| 11601 |         ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);
| 11602 | |
| 11603 | break; |
| 11604 | case GGML_OP_SOFT_MAX: |
| 11605 | if (ctx->num_additional_fused_ops) { |
| 11606 |             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
| 11607 |         } else {
| 11608 |             ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
| 11609 | } |
| 11610 | |
| 11611 | break; |
| 11612 | case GGML_OP_SOFT_MAX_BACK: |
| 11613 |         ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node);
| 11614 | |
| 11615 | break; |
| 11616 | case GGML_OP_ROPE: |
| 11617 |         ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, /*backprop=*/false);
| 11618 | |
| 11619 | break; |
| 11620 | case GGML_OP_ROPE_BACK: |
| 11621 |         ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, /*backprop=*/true);
| 11622 | |
| 11623 | break; |
| 11624 | case GGML_OP_ARGSORT: |
| 11625 | if (ctx->num_additional_fused_ops) { |
| 11626 |             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
| 11627 |         } else {
| 11628 |             ggml_vk_argsort(ctx, compute_ctx, src0, node);
| 11629 | } |
| 11630 | |
| 11631 | break; |
| 11632 | case GGML_OP_SUM: |
| 11633 |         ggml_vk_sum(ctx, compute_ctx, src0, node);
| 11634 | |
| 11635 | break; |
| 11636 | case GGML_OP_SUM_ROWS: |
| 11637 |         ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
| 11638 | |
| 11639 | break; |
| 11640 | case GGML_OP_MEAN: |
| 11641 |         ggml_vk_mean(ctx, compute_ctx, src0, node);
| 11642 | |
| 11643 | break; |
| 11644 | case GGML_OP_ARGMAX: |
| 11645 |         ggml_vk_argmax(ctx, compute_ctx, src0, node);
| 11646 | |
| 11647 | break; |
| 11648 | case GGML_OP_COUNT_EQUAL: |
| 11649 |         ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
| 11650 | |
| 11651 | break; |
| 11652 | case GGML_OP_IM2COL: |
| 11653 |         ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
| 11654 | |
| 11655 | break; |
| 11656 | case GGML_OP_IM2COL_3D: |
| 11657 |         ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node);
| 11658 | |
| 11659 | break; |
| 11660 | case GGML_OP_TIMESTEP_EMBEDDING: |
| 11661 |         ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
| 11662 | |
| 11663 | break; |
| 11664 | case GGML_OP_CONV_TRANSPOSE_1D: |
| 11665 |         ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
| 11666 | |
| 11667 | break; |
| 11668 | case GGML_OP_POOL_2D: |
| 11669 |         ggml_vk_pool_2d(ctx, compute_ctx, src0, node);
| 11670 | |
| 11671 | break; |
| 11672 | case GGML_OP_CONV_2D: |
| 11673 |         ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
| 11674 | |
| 11675 | break; |
| 11676 | case GGML_OP_CONV_TRANSPOSE_2D: |
| 11677 |         ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node);
| 11678 | |
| 11679 | break; |
| 11680 | case GGML_OP_CONV_2D_DW: |
| 11681 |         ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
| 11682 | |
| 11683 | break; |
| 11684 | case GGML_OP_LEAKY_RELU: |
| 11685 |         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);
| 11686 | |
| 11687 | break; |
| 11688 | case GGML_OP_MUL_MAT: |
| 11689 |         ggml_vk_mul_mat(ctx, compute_ctx, cgraph, node_idx);
| 11690 | |
| 11691 | break; |
| 11692 | case GGML_OP_MUL_MAT_ID: |
| 11693 |         ggml_vk_mul_mat_id(ctx, compute_ctx, cgraph, node_idx);
| 11694 | |
| 11695 | break; |
| 11696 | |
| 11697 | case GGML_OP_FLASH_ATTN_EXT: |
| 11698 |         ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node);
| 11699 | |
| 11700 | break; |
| 11701 | |
| 11702 | case GGML_OP_RWKV_WKV6: |
| 11703 | ggml_vk_rwkv_wkv6(ctx, compute_ctx, node);
| 11704 | |
| 11705 | break; |
| 11706 | |
| 11707 | case GGML_OP_RWKV_WKV7: |
| 11708 | ggml_vk_rwkv_wkv7(ctx, compute_ctx, node);
| 11709 | |
| 11710 | break; |
| 11711 | |
| 11712 | case GGML_OP_SSM_SCAN: |
| 11713 | ggml_vk_ssm_scan(ctx, compute_ctx, node);
| 11714 | |
| 11715 | break; |
| 11716 | |
| 11717 | case GGML_OP_SSM_CONV: |
| 11718 | ggml_vk_ssm_conv(ctx, compute_ctx, node);
| 11719 | |
| 11720 | break; |
| 11721 | |
| 11722 | case GGML_OP_OPT_STEP_ADAMW: |
| 11723 | ggml_vk_opt_step_adamw(ctx, compute_ctx, node);
| 11724 | |
| 11725 | break; |
| 11726 | |
| 11727 | case GGML_OP_OPT_STEP_SGD: |
| 11728 | ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node);
| 11729 | |
| 11730 | break; |
| 11731 | default: |
| 11732 | return false; |
| 11733 | } |
| 11734 | |
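| | // Record which vk_context this node was encoded into; ggml_vk_compute_forward
| | // later locks this weak pointer to find the right submission for the node.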
| 11735 | ctx->tensor_ctxs[node_idx] = compute_ctx; |
| 11736 | |
| 11737 | #if defined(GGML_VULKAN_CHECK_RESULTS) |
| 11738 | // Force context reset on each node so that each tensor ends up in its own context |
| 11739 | // and can be run and compared to its CPU equivalent separately |
| 11740 | last_node = true; |
| 11741 | #endif |
| 11742 | |
| 11743 | if (submit || last_node) { |
| 11744 | ggml_vk_ctx_end(compute_ctx);
| 11745 | |
| 11746 | // TODO probably it'd be better to pass an exit_node flag to ggml_vk_compute_forward
| 11747 | if (last_node) { |
| 11748 | compute_ctx->exit_tensor_idx = node_idx_begin; |
| 11749 | } |
| 11750 | else { |
| 11751 | compute_ctx->exit_tensor_idx = -1; |
| 11752 | } |
| 11753 | |
| 11754 | ctx->compute_ctx.reset(); |
| 11755 | |
| 11756 | bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
| 11757 | if (!ok) { |
| 11758 | if (node->op == GGML_OP_UNARY) { |
| 11759 | std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(op: static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl; |
| 11760 | } else if (node->op == GGML_OP_GLU) { |
| 11761 | std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(op: static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl; |
| 11762 | } else { |
| 11763 | std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(op: node->op) << ")" << std::endl; |
| 11764 | } |
| 11765 | } |
| 11766 | |
| 11767 | } |
| 11768 | return true; |
| 11769 | } |
| 11770 | |
| 11771 | static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { |
| 11772 | GGML_UNUSED(cgraph); |
| 11773 | ggml_backend_buffer * buf = nullptr; |
| 11774 | |
| 11775 | switch (tensor->op) { |
| 11776 | case GGML_OP_ADD: |
| 11777 | case GGML_OP_ACC: |
| 11778 | case GGML_OP_GET_ROWS: |
| 11779 | case GGML_OP_SUB: |
| 11780 | case GGML_OP_MUL: |
| 11781 | case GGML_OP_DIV: |
| 11782 | case GGML_OP_ADD_ID: |
| 11783 | case GGML_OP_CONCAT: |
| 11784 | case GGML_OP_UPSCALE: |
| 11785 | case GGML_OP_SCALE: |
| 11786 | case GGML_OP_SQR: |
| 11787 | case GGML_OP_SQRT: |
| 11788 | case GGML_OP_SIN: |
| 11789 | case GGML_OP_COS: |
| 11790 | case GGML_OP_CLAMP: |
| 11791 | case GGML_OP_PAD: |
| 11792 | case GGML_OP_ROLL: |
| 11793 | case GGML_OP_CPY: |
| 11794 | case GGML_OP_SET_ROWS: |
| 11795 | case GGML_OP_CONT: |
| 11796 | case GGML_OP_DUP: |
| 11797 | case GGML_OP_SILU_BACK: |
| 11798 | case GGML_OP_NORM: |
| 11799 | case GGML_OP_GROUP_NORM: |
| 11800 | case GGML_OP_RMS_NORM: |
| 11801 | case GGML_OP_RMS_NORM_BACK: |
| 11802 | case GGML_OP_L2_NORM: |
| 11803 | case GGML_OP_DIAG_MASK_INF: |
| 11804 | case GGML_OP_SOFT_MAX: |
| 11805 | case GGML_OP_SOFT_MAX_BACK: |
| 11806 | case GGML_OP_ROPE: |
| 11807 | case GGML_OP_ROPE_BACK: |
| 11808 | case GGML_OP_RESHAPE: |
| 11809 | case GGML_OP_VIEW: |
| 11810 | case GGML_OP_PERMUTE: |
| 11811 | case GGML_OP_TRANSPOSE: |
| 11812 | case GGML_OP_NONE: |
| 11813 | case GGML_OP_ARGSORT: |
| 11814 | case GGML_OP_SUM: |
| 11815 | case GGML_OP_SUM_ROWS: |
| 11816 | case GGML_OP_MEAN: |
| 11817 | case GGML_OP_ARGMAX: |
| 11818 | case GGML_OP_COUNT_EQUAL: |
| 11819 | case GGML_OP_IM2COL: |
| 11820 | case GGML_OP_IM2COL_3D: |
| 11821 | case GGML_OP_TIMESTEP_EMBEDDING: |
| 11822 | case GGML_OP_CONV_TRANSPOSE_1D: |
| 11823 | case GGML_OP_POOL_2D: |
| 11824 | case GGML_OP_CONV_2D: |
| 11825 | case GGML_OP_CONV_TRANSPOSE_2D: |
| 11826 | case GGML_OP_CONV_2D_DW: |
| 11827 | case GGML_OP_RWKV_WKV6: |
| 11828 | case GGML_OP_RWKV_WKV7: |
| 11829 | case GGML_OP_SSM_SCAN: |
| 11830 | case GGML_OP_SSM_CONV: |
| 11831 | case GGML_OP_LEAKY_RELU: |
| 11832 | case GGML_OP_REPEAT: |
| 11833 | case GGML_OP_REPEAT_BACK: |
| 11834 | case GGML_OP_OPT_STEP_ADAMW: |
| 11835 | case GGML_OP_OPT_STEP_SGD: |
| 11836 | buf = tensor->buffer; |
| 11837 | break; |
| 11838 | case GGML_OP_UNARY: |
| 11839 | switch (ggml_get_unary_op(tensor)) { |
| 11840 | case GGML_UNARY_OP_EXP: |
| 11841 | case GGML_UNARY_OP_SILU: |
| 11842 | case GGML_UNARY_OP_GELU: |
| 11843 | case GGML_UNARY_OP_GELU_ERF: |
| 11844 | case GGML_UNARY_OP_GELU_QUICK: |
| 11845 | case GGML_UNARY_OP_RELU: |
| 11846 | case GGML_UNARY_OP_TANH: |
| 11847 | case GGML_UNARY_OP_SIGMOID: |
| 11848 | case GGML_UNARY_OP_HARDSIGMOID: |
| 11849 | case GGML_UNARY_OP_HARDSWISH: |
| 11850 | buf = tensor->buffer; |
| 11851 | break; |
| 11852 | default: |
| 11853 | return false; |
| 11854 | } |
| 11855 | break; |
| 11856 | case GGML_OP_GLU: |
| 11857 | switch (ggml_get_glu_op(tensor)) { |
| 11858 | case GGML_GLU_OP_GEGLU: |
| 11859 | case GGML_GLU_OP_REGLU: |
| 11860 | case GGML_GLU_OP_SWIGLU: |
| 11861 | case GGML_GLU_OP_SWIGLU_OAI: |
| 11862 | case GGML_GLU_OP_GEGLU_ERF: |
| 11863 | case GGML_GLU_OP_GEGLU_QUICK: |
| 11864 | buf = tensor->buffer; |
| 11865 | break; |
| 11866 | default: |
| 11867 | return false; |
| 11868 | } |
| 11869 | break; |
| 11870 | case GGML_OP_MUL_MAT: |
| 11871 | case GGML_OP_MUL_MAT_ID: |
| 11872 | case GGML_OP_FLASH_ATTN_EXT: |
| 11873 | buf = tensor->buffer; |
| 11874 | |
| 11875 | break; |
| 11876 | default: |
| 11877 | return false; |
| 11878 | } |
| 11879 | |
| 11880 | if (buf == nullptr) { |
| 11881 | return false; |
| 11882 | } |
| 11883 | |
| 11884 | VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")" ); |
| 11885 | |
| 11886 | vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); |
| 11887 | |
| 11888 | // always wait for the GPU work to be done for the last submit |
| 11889 | if (tensor_idx == subctx->exit_tensor_idx) { |
| 11890 | use_fence = true; |
| 11891 | } |
| 11892 | |
| 11893 | // Only run if ctx hasn't been submitted yet |
| 11894 | if (!subctx->seqs.empty()) { |
| 11895 | #ifdef GGML_VULKAN_CHECK_RESULTS |
| 11896 | ggml_vk_check_results_0(ctx, cgraph, tensor_idx); |
| 11897 | use_fence = true; |
| 11898 | #endif |
| 11899 | |
| 11900 | // Do staging buffer copies |
| 11901 | for (auto& cpy : subctx->in_memcpys) { |
| 11902 | memcpy(cpy.dst, cpy.src, cpy.n);
| 11903 | } |
| 11904 | |
| 11905 | for (auto& mset : subctx->memsets) { |
| 11906 | memset(mset.dst, mset.val, mset.n);
| 11907 | } |
| 11908 | |
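| | // Submit either with the dedicated almost_ready fence (when requested and not
| | // already in flight), with the main fence when a synchronous wait is needed,
| | // or with no fence at all for intermediate submissions.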
| 11909 | if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { |
| 11910 | ggml_vk_submit(subctx, ctx->almost_ready_fence);
| 11911 | ctx->almost_ready_fence_pending = true; |
| 11912 | } else { |
| 11913 | ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
| 11914 | } |
| 11915 | |
| 11916 | if (use_fence) { |
| 11917 | ggml_vk_wait_for_fence(ctx); |
| 11918 | } |
| 11919 | #ifdef GGML_VULKAN_CHECK_RESULTS |
| 11920 | ggml_vk_check_results_1(ctx, cgraph, tensor_idx); |
| 11921 | #endif |
| 11922 | } |
| 11923 | |
| 11924 | if (tensor_idx == subctx->exit_tensor_idx) { |
| 11925 | // Do staging buffer copies |
| 11926 | for (auto& cpy : subctx->out_memcpys) { |
| 11927 | memcpy(cpy.dst, cpy.src, cpy.n);
| 11928 | } |
| 11929 | subctx->in_memcpys.clear(); |
| 11930 | subctx->out_memcpys.clear(); |
| 11931 | subctx->memsets.clear(); |
| 11932 | } |
| 11933 | |
| 11934 | return true; |
| 11935 | } |
| 11936 | |
| 11937 | // Clean up after graph processing is done |
| 11938 | static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { |
| 11939 | VK_LOG_DEBUG("ggml_vk_graph_cleanup()" ); |
| 11940 | ctx->prealloc_y_last_pipeline_used = {}; |
| 11941 | |
| 11942 | ctx->unsynced_nodes_written.clear(); |
| 11943 | ctx->unsynced_nodes_read.clear(); |
| 11944 | ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; |
| 11945 | |
| 11946 | ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
| 11947 | ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
| 11948 | |
| 11949 | for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { |
| 11950 | ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
| 11951 | } |
| 11952 | ctx->gc.semaphores.clear(); |
| 11953 | |
| 11954 | for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { |
| 11955 | ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
| 11956 | } |
| 11957 | ctx->gc.tl_semaphores.clear(); |
| 11958 | ctx->semaphore_idx = 0; |
| 11959 | |
| 11960 | ctx->event_idx = 0; |
| 11961 | |
| 11962 | for (auto& event : ctx->gc.events) { |
| 11963 | ctx->device->device.resetEvent(event); |
| 11964 | } |
| 11965 | |
| 11966 | ctx->tensor_ctxs.clear(); |
| 11967 | ctx->gc.contexts.clear(); |
| 11968 | ctx->pipeline_descriptor_set_requirements = 0; |
| 11969 | ctx->descriptor_set_idx = 0; |
| 11970 | } |
| 11971 | |
| 11972 | // Clean up on backend free |
| 11973 | static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { |
| 11974 | VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")" ); |
| 11975 | ggml_vk_graph_cleanup(ctx); |
| 11976 | |
| 11977 | ggml_vk_destroy_buffer(ctx->prealloc_x);
| 11978 | ggml_vk_destroy_buffer(ctx->prealloc_y);
| 11979 | ggml_vk_destroy_buffer(ctx->prealloc_split_k);
| 11980 | ctx->prealloc_y_last_pipeline_used = nullptr; |
| 11981 | |
| 11982 | ctx->prealloc_size_x = 0; |
| 11983 | ctx->prealloc_size_y = 0; |
| 11984 | ctx->prealloc_size_split_k = 0; |
| 11985 | |
| 11986 | for (auto& event : ctx->gc.events) { |
| 11987 | ctx->device->device.destroyEvent(event); |
| 11988 | } |
| 11989 | ctx->gc.events.clear(); |
| 11990 | |
| 11991 | ctx->device->device.destroyFence(ctx->fence);
| 11992 | ctx->device->device.destroyFence(ctx->almost_ready_fence);
| 11993 | |
| 11994 | for (auto& pool : ctx->descriptor_pools) { |
| 11995 | ctx->device->device.destroyDescriptorPool(pool);
| 11996 | } |
| 11997 | ctx->descriptor_pools.clear(); |
| 11998 | ctx->descriptor_sets.clear(); |
| 11999 | |
| 12000 | ctx->compute_cmd_pool.destroy(ctx->device->device);
| 12001 | ctx->transfer_cmd_pool.destroy(ctx->device->device);
| 12002 | } |
| 12003 | |
| 12004 | static int ggml_vk_get_device_count() { |
| 12005 | ggml_vk_instance_init(); |
| 12006 | |
| 12007 | return vk_instance.device_indices.size(); |
| 12008 | } |
| 12009 | |
| 12010 | static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { |
| 12011 | ggml_vk_instance_init(); |
| 12012 | |
| 12013 | std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices(); |
| 12014 | |
| 12015 | vk::PhysicalDeviceProperties props; |
| 12016 | devices[device].getProperties(&props);
| 12017 | |
| 12018 | snprintf(description, description_size, "%s", props.deviceName.data());
| 12019 | } |
| 12020 | |
| 12021 | // backend interface |
| 12022 | |
| 12023 | #define UNUSED GGML_UNUSED |
| 12024 | |
| 12025 | // device backend |
| 12026 | |
| 12027 | static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { |
| 12028 | return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name; |
| 12029 | } |
| 12030 | |
| 12031 | static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { |
| 12032 | VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()" ); |
| 12033 | ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; |
| 12034 | ggml_vk_destroy_buffer(ctx->dev_buffer);
| 12035 | delete ctx; |
| 12036 | } |
| 12037 | |
| 12038 | static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { |
| 12039 | return vk_ptr_base; |
| 12040 | |
| 12041 | UNUSED(buffer); |
| 12042 | } |
| 12043 | |
| 12044 | static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { |
| 12045 | VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")" ); |
| 12046 | if (tensor->view_src != nullptr) { |
| 12047 | GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); |
| 12048 | } |
| 12049 | return GGML_STATUS_SUCCESS; |
| 12050 | } |
| 12051 | |
| 12052 | static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { |
| 12053 | VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")" ); |
| 12054 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; |
| 12055 | vk_buffer buf = buf_ctx->dev_buffer; |
| 12056 | |
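| | // Replicate the byte value into all four bytes of a 32-bit word, since the
| | // buffer memset below operates on 32-bit values.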
| 12057 | uint32_t val32 = (uint32_t)value * 0x01010101; |
| 12058 | ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
| 12059 | } |
| 12060 | |
| 12061 | static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { |
| 12062 | VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" ); |
| 12063 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; |
| 12064 | vk_buffer buf = buf_ctx->dev_buffer; |
| 12065 | |
| 12066 | ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
| 12067 | } |
| 12068 | |
| 12069 | static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { |
| 12070 | VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" ); |
| 12071 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; |
| 12072 | |
| 12073 | vk_buffer buf = buf_ctx->dev_buffer; |
| 12074 | |
| 12075 | ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
| 12076 | } |
| 12077 | |
| 12078 | static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { |
| 12079 | if (ggml_backend_buffer_is_vk(src->buffer)) {
| 12080 | ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; |
| 12081 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; |
| 12082 | |
| 12083 | vk_buffer src_buf = src_buf_ctx->dev_buffer; |
| 12084 | vk_buffer dst_buf = dst_buf_ctx->dev_buffer; |
| 12085 | |
| 12086 | ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
| 12087 | |
| 12088 | return true; |
| 12089 | } |
| 12090 | return false; |
| 12091 | |
| 12092 | UNUSED(buffer); |
| 12093 | } |
| 12094 | |
| 12095 | static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { |
| 12096 | ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; |
| 12097 | |
| 12098 | ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
| 12099 | } |
| 12100 | |
| 12101 | static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { |
| 12102 | /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, |
| 12103 | /* .get_base = */ ggml_backend_vk_buffer_get_base, |
| 12104 | /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, |
| 12105 | /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, |
| 12106 | /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, |
| 12107 | /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, |
| 12108 | /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, |
| 12109 | /* .clear = */ ggml_backend_vk_buffer_clear, |
| 12110 | /* .reset = */ NULL, |
| 12111 | }; |
| 12112 | |
| 12113 | // vk buffer type |
| 12114 | static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { |
| 12115 | ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context; |
| 12116 | |
| 12117 | return ctx->name.c_str(); |
| 12118 | } |
| 12119 | |
| 12120 | static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { |
| 12121 | VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")" ); |
| 12122 | ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; |
| 12123 | |
| 12124 | vk_buffer dev_buffer = nullptr; |
| 12125 | try { |
| 12126 | dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
| 12127 | } catch (const vk::SystemError& e) { |
| 12128 | return nullptr; |
| 12129 | } |
| 12130 | |
| 12131 | ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); |
| 12132 | |
| 12133 | return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
| 12134 | } |
| 12135 | |
| 12136 | static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { |
| 12137 | ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; |
| 12138 | return ctx->device->properties.limits.minStorageBufferOffsetAlignment; |
| 12139 | } |
| 12140 | |
| 12141 | static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { |
| 12142 | ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; |
| 12143 | return ctx->device->suballocation_block_size; |
| 12144 | } |
| 12145 | |
| 12146 | static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { |
| 12147 | return ggml_nbytes(tensor); |
| 12148 | |
| 12149 | UNUSED(buft); |
| 12150 | } |
| 12151 | |
| 12152 | ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { |
| 12153 | ggml_vk_instance_init(); |
| 12154 | |
| 12155 | VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")" ); |
| 12156 | |
| 12157 | vk_device dev = ggml_vk_get_device(dev_num);
| 12158 | |
| 12159 | return &dev->buffer_type; |
| 12160 | } |
| 12161 | |
| 12162 | // host buffer type |
| 12163 | |
| 12164 | static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { |
| 12165 | return GGML_VK_NAME "_Host" ; |
| 12166 | |
| 12167 | UNUSED(buft); |
| 12168 | } |
| 12169 | |
| 12170 | static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { |
| 12171 | return GGML_VK_NAME "_Host" ; |
| 12172 | |
| 12173 | UNUSED(buffer); |
| 12174 | } |
| 12175 | |
| 12176 | static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { |
| 12177 | VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()" ); |
| 12178 | ggml_vk_host_free(vk_instance.devices[0], buffer->context);
| 12179 | } |
| 12180 | |
| 12181 | static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { |
| 12182 | VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")" ); |
| 12183 | |
| 12184 | size += 32; // Behave like the CPU buffer type |
| 12185 | void * ptr = nullptr; |
| 12186 | try { |
| 12187 | ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
| 12188 | } catch (vk::SystemError& e) { |
| 12189 | GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n" , e.what()); |
| 12190 | // fallback to cpu buffer |
| 12191 | return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
| 12192 | } |
| 12193 | |
| 12194 | ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); |
| 12195 | buffer->buft = buft; |
| 12196 | buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; |
| 12197 | |
| 12198 | return buffer; |
| 12199 | |
| 12200 | UNUSED(buft); |
| 12201 | } |
| 12202 | |
| 12203 | static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { |
| 12204 | return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; |
| 12205 | |
| 12206 | UNUSED(buft); |
| 12207 | } |
| 12208 | |
| 12209 | static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { |
| 12210 | return vk_instance.devices[0]->suballocation_block_size; |
| 12211 | |
| 12212 | UNUSED(buft); |
| 12213 | } |
| 12214 | |
| 12215 | // Should be changed to return device-specific host buffer type |
| 12216 | // but that probably requires changes in llama.cpp |
| 12217 | ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { |
| 12218 | static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { |
| 12219 | /* .iface = */ { |
| 12220 | /* .get_name = */ ggml_backend_vk_host_buffer_type_name, |
| 12221 | /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, |
| 12222 | /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, |
| 12223 | /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size, |
| 12224 | /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, |
| 12225 | /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, |
| 12226 | }, |
| 12227 | /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
| 12228 | /* .context = */ nullptr, |
| 12229 | }; |
| 12230 | |
| 12231 | // Make sure device 0 is initialized |
| 12232 | ggml_vk_instance_init(); |
| 12233 | ggml_vk_get_device(0);
| 12234 | |
| 12235 | return &ggml_backend_vk_buffer_type_host; |
| 12236 | } |
| 12237 | |
| 12238 | |
| 12239 | // backend |
| 12240 | |
| 12241 | static const char * ggml_backend_vk_name(ggml_backend_t backend) { |
| 12242 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12243 | |
| 12244 | return ctx->name.c_str(); |
| 12245 | } |
| 12246 | |
| 12247 | static void ggml_backend_vk_free(ggml_backend_t backend) { |
| 12248 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12249 | VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")" ); |
| 12250 | |
| 12251 | ggml_vk_cleanup(ctx); |
| 12252 | |
| 12253 | delete ctx; |
| 12254 | delete backend; |
| 12255 | } |
| 12256 | |
| 12257 | static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { |
| 12258 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12259 | |
| 12260 | return &ctx->device->buffer_type; |
| 12261 | } |
| 12262 | |
| 12263 | static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { |
| 12264 | VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")" ); |
| 12265 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12266 | GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type" ); |
| 12267 | |
| 12268 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; |
| 12269 | |
| 12270 | vk_context transfer_ctx; |
| 12271 | |
| 12272 | if (ctx->transfer_ctx.expired()) { |
| 12273 | // Initialize new transfer context |
| 12274 | transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
| 12275 | ctx->transfer_ctx = transfer_ctx; |
| 12276 | ggml_vk_ctx_begin(ctx->device, transfer_ctx);
| 12277 | } else { |
| 12278 | transfer_ctx = ctx->transfer_ctx.lock(); |
| 12279 | } |
| 12280 | |
| 12281 | vk_buffer buf = buf_ctx->dev_buffer; |
| 12282 | |
| 12283 | ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
| 12284 | } |
| 12285 | |
| 12286 | static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { |
| 12287 | VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")" ); |
| 12288 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12289 | GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type" ); |
| 12290 | |
| 12291 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; |
| 12292 | |
| 12293 | vk_context transfer_ctx; |
| 12294 | |
| 12295 | if (ctx->transfer_ctx.expired()) { |
| 12296 | // Initialize new transfer context |
| 12297 | transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
| 12298 | ctx->transfer_ctx = transfer_ctx; |
| 12299 | ggml_vk_ctx_begin(ctx->device, transfer_ctx);
| 12300 | } else { |
| 12301 | transfer_ctx = ctx->transfer_ctx.lock(); |
| 12302 | } |
| 12303 | |
| 12304 | vk_buffer buf = buf_ctx->dev_buffer; |
| 12305 | |
| 12306 | ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
| 12307 | } |
| 12308 | |
| 12309 | static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { |
| 12310 | VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()" ); |
| 12311 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12312 | if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
| 12313 | ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; |
| 12314 | ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; |
| 12315 | |
| 12316 | vk_context transfer_ctx; |
| 12317 | |
| 12318 | if (ctx->transfer_ctx.expired()) { |
| 12319 | // Initialize new transfer context |
| 12320 | transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
| 12321 | ctx->transfer_ctx = transfer_ctx; |
| 12322 | ggml_vk_ctx_begin(ctx->device, transfer_ctx);
| 12323 | } else { |
| 12324 | transfer_ctx = ctx->transfer_ctx.lock(); |
| 12325 | } |
| 12326 | |
| 12327 | vk_buffer src_buf = src_buf_ctx->dev_buffer; |
| 12328 | vk_buffer dst_buf = dst_buf_ctx->dev_buffer; |
| 12329 | |
| 12330 | ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
| 12331 | return true; |
| 12332 | } |
| 12333 | |
| 12334 | return false; |
| 12335 | } |
| 12336 | |
| 12337 | static void ggml_backend_vk_synchronize(ggml_backend_t backend) { |
| 12338 | VK_LOG_DEBUG("ggml_backend_vk_synchronize()" ); |
| 12339 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12340 | if(ctx->transfer_ctx.expired()) { |
| 12341 | return; |
| 12342 | } |
| 12343 | |
| 12344 | vk_context transfer_ctx = ctx->transfer_ctx.lock(); |
| 12345 | |
| 12346 | ggml_vk_ctx_end(transfer_ctx);
| 12347 | |
| 12348 | for (auto& cpy : transfer_ctx->in_memcpys) { |
| 12349 | memcpy(cpy.dst, cpy.src, cpy.n);
| 12350 | } |
| 12351 | |
| 12352 | ggml_vk_submit(transfer_ctx, ctx->fence);
| 12353 | ggml_vk_wait_for_fence(ctx); |
| 12354 | |
| 12355 | for (auto& cpy : transfer_ctx->out_memcpys) { |
| 12356 | memcpy(cpy.dst, cpy.src, cpy.n);
| 12357 | } |
| 12358 | |
| 12359 | ctx->transfer_ctx.reset(); |
| 12360 | } |
| 12361 | |
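| | // A node is "empty" for the Vulkan backend if it produces no GPU work: either it
| | // has no elements or it is a pure layout op (reshape/view/permute/transpose).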
| 12362 | static bool ggml_vk_is_empty(ggml_tensor * node) { |
| 12363 | return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
| 12364 | } |
| 12365 | |
| 12366 | static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) { |
| 12367 | if (!ggml_can_fuse(cgraph, node_idx, ops)) { |
| 12368 | return false; |
| 12369 | } |
| 12370 | |
| 12371 | if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { |
| 12372 | // additional constraints specific to this fusion |
| 12373 | const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; |
| 12374 | const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; |
| 12375 | |
| 12376 | GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); |
| 12377 | GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); |
| 12378 | // rms_norm only supports f32 |
| 12379 | if (mul->src[0]->type != GGML_TYPE_F32 || |
| 12380 | mul->src[1]->type != GGML_TYPE_F32 || |
| 12381 | mul->type != GGML_TYPE_F32) { |
| 12382 | return false; |
| 12383 | } |
| 12384 | // if rms_norm is the B operand, then we don't handle broadcast |
| 12385 | if (rms_norm == mul->src[1] && |
| 12386 | !ggml_are_same_shape(mul->src[0], rms_norm)) {
| 12387 | return false; |
| 12388 | } |
| 12389 | // rms_norm shader assumes contiguous rows |
| 12390 | if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
| 12391 | return false; |
| 12392 | } |
| 12393 | } |
| 12394 | if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) { |
| 12395 | // additional constraints specific to this fusion |
| 12396 | const ggml_tensor *mul = cgraph->nodes[node_idx]; |
| 12397 | const ggml_tensor *add = cgraph->nodes[node_idx + 1]; |
| 12398 | const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0]; |
| 12399 | |
| 12400 | // mat-vec only |
| 12401 | if (ggml_nrows(mul) != 1) {
| 12402 | return false; |
| 12403 | } |
| 12404 | // shaders assume the types match |
| 12405 | if (mul->type != bias->type) { |
| 12406 | return false; |
| 12407 | } |
| 12408 | // shaders reuse the D shape for bias |
| 12409 | if (!ggml_are_same_shape(mul, bias) ||
| 12410 | !ggml_are_same_stride(mul, bias)) {
| 12411 | return false; |
| 12412 | } |
| 12413 | // unaligned bias isn't handled |
| 12414 | if (get_misalign_bytes(ctx, bias) != 0) {
| 12415 | return false; |
| 12416 | } |
| 12417 | } |
| 12418 | if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) { |
| 12419 | // additional constraints specific to this fusion |
| 12420 | const ggml_tensor *mul = cgraph->nodes[node_idx]; |
| 12421 | const ggml_tensor *add = cgraph->nodes[node_idx + 1]; |
| 12422 | const ggml_tensor *bias = add->src[1]; |
| 12423 | |
| 12424 | if (mul != add->src[0]) { |
| 12425 | return false; |
| 12426 | } |
| 12427 | // mat-vec only |
| 12428 | if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) { |
| 12429 | return false; |
| 12430 | } |
| 12431 | // shaders assume the types match |
| 12432 | if (mul->type != bias->type) { |
| 12433 | return false; |
| 12434 | } |
| 12435 | // shaders assume the bias is contiguous |
| 12436 | if (!ggml_is_contiguous(bias)) {
| 12437 | return false; |
| 12438 | } |
| 12439 | // the ID tensor must be the same for mul_mat_id and add_id |
| 12440 | if (mul->src[2] != add->src[2]) { |
| 12441 | return false; |
| 12442 | } |
| 12443 | // unaligned bias isn't handled |
| 12444 | if (get_misalign_bytes(ctx, bias) != 0) {
| 12445 | return false; |
| 12446 | } |
| 12447 | } |
| 12448 | |
| 12449 | return true; |
| 12450 | } |
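| | // Example call (see the graph loop below): ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })
| | // first applies the generic ggml_can_fuse check, then the RMS_NORM+MUL-specific constraints above.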
| 12451 | |
| 12452 | static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, |
| 12453 | int node_idx, topk_moe_mode mode) { |
| 12454 | |
| 12455 | const ggml_tensor * softmax; |
| 12456 | const ggml_tensor * weights; |
| 12457 | |
| 12458 | switch (mode) { |
| 12459 | case TOPK_MOE_EARLY_SOFTMAX_NORM: |
| 12460 | softmax = cgraph->nodes[node_idx + 0]; |
| 12461 | weights = cgraph->nodes[node_idx + 9]; |
| 12462 | break; |
| 12463 | case TOPK_MOE_EARLY_SOFTMAX: |
| 12464 | softmax = cgraph->nodes[node_idx + 0]; |
| 12465 | weights = cgraph->nodes[node_idx + 4]; |
| 12466 | break; |
| 12467 | case TOPK_MOE_LATE_SOFTMAX: |
| 12468 | softmax = cgraph->nodes[node_idx + 4]; |
| 12469 | weights = cgraph->nodes[node_idx + 5]; |
| 12470 | break; |
| 12471 | default: |
| 12472 | return false; |
| 12473 | } |
| 12474 | |
| 12475 | const float * op_params = (const float *)softmax->op_params; |
| 12476 | |
| 12477 | float scale = op_params[0]; |
| 12478 | float max_bias = op_params[1]; |
| 12479 | |
| 12480 | if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
| 12481 | return false; |
| 12482 | } |
| 12483 | |
| 12484 | if (scale != 1.0f || max_bias != 0.0f) { |
| 12485 | return false; |
| 12486 | } |
| 12487 | |
| 12488 | // don't fuse when masks or sinks are present |
| 12489 | if (softmax->src[1] || softmax->src[2]) { |
| 12490 | return false; |
| 12491 | } |
| 12492 | |
| 12493 | const int n_expert = softmax->ne[0]; |
| 12494 | // n_expert must be a power of 2 |
| 12495 | if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
| 12496 | return false; |
| 12497 | } |
| 12498 | |
| 12499 | if (!ctx->device->subgroup_arithmetic || |
| 12500 | !ctx->device->subgroup_shuffle || |
| 12501 | !ctx->device->subgroup_require_full_support || |
| 12502 | ctx->device->disable_fusion) { |
| 12503 | return false; |
| 12504 | } |
| 12505 | |
| 12506 | return true; |
| 12507 | } |
| 12508 | |
| 12509 | static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, |
| 12510 | int node_idx) { |
| 12511 | GGML_UNUSED(ctx); |
| 12512 | const ggml_tensor *rope = cgraph->nodes[node_idx + 0]; |
| 12513 | const ggml_tensor *view = cgraph->nodes[node_idx + 1]; |
| 12514 | const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2]; |
| 12515 | |
| 12516 | // ne3 not tested |
| 12517 | if (rope->src[0]->ne[3] != 1) { |
| 12518 | return false; |
| 12519 | } |
| 12520 | |
| 12521 | if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) { |
| 12522 | return false; |
| 12523 | } |
| 12524 | |
| 12525 | if (set_rows->src[1]->type != GGML_TYPE_I64) { |
| 12526 | return false; |
| 12527 | } |
| 12528 | |
| 12529 | // The view should flatten two dims of rope into one dim |
| 12530 | if (!ggml_is_contiguous(view) ||
| 12531 | view->ne[0] != rope->ne[0] * rope->ne[1]) { |
| 12532 | return false; |
| 12533 | } |
| 12534 | |
| 12535 | // Only norm/neox shaders have the fusion code |
| 12536 | const int mode = ((const int32_t *) rope->op_params)[2]; |
| 12537 | if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { |
| 12538 | return false; |
| 12539 | } |
| 12540 | |
| 12541 | return true; |
| 12542 | } |
| 12543 | |
| 12544 | // Check whether the tensors overlap in memory but are not equal. |
| 12545 | // Fusions can potentially overwrite src tensors in ways that are not prevented
| 12546 | // by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them |
| 12547 | // to overlap if they are exactly equal. |
| 12548 | // XXX TODO this check is probably missing from several fusion optimizations. |
| 12549 | static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) { |
| 12550 | ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context; |
| 12551 | vk_buffer a_buf = a_buf_ctx->dev_buffer; |
| 12552 | ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context; |
| 12553 | vk_buffer b_buf = b_buf_ctx->dev_buffer; |
| 12554 | if (a_buf == b_buf) { |
| 12555 | auto a_base = vk_tensor_offset(a) + a->view_offs;
| 12556 | auto a_size = ggml_nbytes(a);
| 12557 | auto b_base = vk_tensor_offset(b) + b->view_offs;
| 12558 | auto b_size = ggml_nbytes(b);
| 12559 | |
| 12560 | if (a_base == b_base && a_size == b_size) { |
| 12561 | return false; |
| 12562 | } |
| 12563 | |
| 12564 | if ((b_base <= a_base && a_base < b_base + b_size) || |
| 12565 | (a_base <= b_base && b_base < a_base + a_size)) { |
| 12566 | return true; |
| 12567 | } |
| 12568 | } |
| 12569 | return false; |
| 12570 | } |
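| | // Example: two views of the same buffer at byte ranges [0, 256) and [128, 384)
| | // overlap but are not equal, so a fusion that writes one while reading the other is rejected.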
| 12571 | |
| 12572 | static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, |
| 12573 | int node_idx) { |
| 12574 | GGML_UNUSED(ctx); |
| 12575 | const ggml_tensor *rms = cgraph->nodes[node_idx + 0]; |
| 12576 | const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; |
| 12577 | const ggml_tensor *rope = cgraph->nodes[node_idx + 2]; |
| 12578 | |
| 12579 | const int mode = ((const int32_t *) rope->op_params)[2]; |
| 12580 | |
| 12581 | // noncontig tensors aren't tested, and don't seem common in practice |
| 12582 | if (!ggml_is_contiguous(rms) ||
| 12583 | !ggml_is_contiguous(mul) ||
| 12584 | !ggml_is_contiguous(rope)) {
| 12585 | return false; |
| 12586 | } |
| 12587 | |
| 12588 | // only norm/neox are handled in the shader |
| 12589 | if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) { |
| 12590 | return false; |
| 12591 | } |
| 12592 | |
| 12593 | // shared memory size for passing data from mul->rope |
| 12594 | if (mul->ne[0] > 1024) { |
| 12595 | return false; |
| 12596 | } |
| 12597 | |
| 12598 | // must not overwrite srcs in a way that's not elementwise |
| 12599 | ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; |
| 12600 | if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
| 12601 | ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
| 12602 | return false; |
| 12603 | } |
| 12604 | |
| 12605 | return true; |
| 12606 | } |
| 12607 | |
| 12608 | static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { |
| 12609 | |
| 12610 | const ggml_tensor *first_node = cgraph->nodes[node_idx]; |
| 12611 | if (first_node->op != GGML_OP_ADD) { |
| 12612 | return 0; |
| 12613 | } |
| 12614 | |
| 12615 | if (!ctx->device->multi_add) { |
| 12616 | return 0; |
| 12617 | } |
| 12618 | |
| 12619 | int32_t num_adds = 1; |
| 12620 | while (node_idx + num_adds < cgraph->n_nodes && |
| 12621 | cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD && |
| 12622 | num_adds < MAX_FUSED_ADDS) { |
| 12623 | num_adds++; |
| 12624 | } |
| 12625 | |
| 12626 | // The shader currently requires same shapes (but different strides are allowed), |
| 12627 | // everything f32, and no misalignment |
| 12628 | for (int32_t i = 0; i < num_adds; ++i) { |
| 12629 | const ggml_tensor *next_node = cgraph->nodes[node_idx + i]; |
| 12630 | if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
| 12631 | !ggml_are_same_shape(first_node, next_node->src[1]) ||
| 12632 | next_node->type != GGML_TYPE_F32 || |
| 12633 | next_node->src[0]->type != GGML_TYPE_F32 || |
| 12634 | next_node->src[1]->type != GGML_TYPE_F32 || |
| 12635 | get_misalign_bytes(ctx, next_node) ||
| 12636 | get_misalign_bytes(ctx, next_node->src[0]) ||
| 12637 | get_misalign_bytes(ctx, next_node->src[1])) {
| 12638 | num_adds = i; |
| 12639 | } |
| 12640 | } |
| 12641 | |
| 12642 | // Verify we can fuse these |
| 12643 | ggml_op adds[MAX_FUSED_ADDS]; |
| 12644 | for (int32_t i = 0; i < num_adds; ++i) { |
| 12645 | adds[i] = GGML_OP_ADD; |
| 12646 | } |
| 12647 | |
| 12648 | // decrease num_adds if they can't all be fused |
| 12649 | while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
| 12650 | num_adds--; |
| 12651 | } |
| 12652 | |
| 12653 | // a single add is not "fused", so just return zero |
| 12654 | if (num_adds == 1) { |
| 12655 | return 0; |
| 12656 | } |
| 12657 | return num_adds; |
| 12658 | } |
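| | // Example: three consecutive eligible ADD nodes make this return 3; the caller then
| | // sets num_additional_fused_ops = 2 so the whole chain is handled as a single fused op.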
| 12659 | |
| 12660 | static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { |
| 12661 | VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)" ); |
| 12662 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12663 | |
| 12664 | if (vk_instance.debug_utils_support) { |
| 12665 | vk::DebugUtilsLabelEXT dul = {}; |
| 12666 | dul.pLabelName = "ggml_backend_vk_graph_compute" ; |
| 12667 | dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f}; |
| 12668 | vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul)); |
| 12669 | } |
| 12670 | |
| 12671 | ctx->prealloc_size_add_rms_partials_offset = 0; |
| 12672 | ctx->do_add_rms_partials = false; |
| 12673 | ctx->do_add_rms_partials_offset_calculation = false; |
| 12674 | |
| 12675 | int last_node = cgraph->n_nodes - 1; |
| 12676 | |
| 12677 | // If the last op in the cgraph doesn't run on the GPU backend, the command buffer doesn't get closed properly
| 12678 | while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
| 12679 | last_node -= 1; |
| 12680 | } |
| 12681 | |
| 12682 | // Reserve tensor context space for all nodes |
| 12683 | ctx->tensor_ctxs.resize(cgraph->n_nodes);
| 12684 | |
| 12685 | bool first_node_in_batch = true; // true if next node will be first node in a batch |
| 12686 | int submit_node_idx = 0; // index to first node in a batch |
| 12687 | |
| 12688 | vk_context compute_ctx; |
| 12689 | if (vk_perf_logger_enabled) { |
| 12690 | // allocate/resize the query pool |
| 12691 | if (ctx->device->num_queries < cgraph->n_nodes + 1) { |
| 12692 | if (ctx->device->query_pool) { |
| 12693 | ctx->device->device.destroyQueryPool(ctx->device->query_pool);
| 12694 | } |
| 12695 | vk::QueryPoolCreateInfo query_create_info; |
| 12696 | query_create_info.queryType = vk::QueryType::eTimestamp; |
| 12697 | query_create_info.queryCount = cgraph->n_nodes + 100; |
| 12698 | ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
| 12699 | ctx->device->num_queries = query_create_info.queryCount; |
| 12700 | } |
| 12701 | |
| 12702 | ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1);
| 12703 | |
| 12704 | GGML_ASSERT(ctx->compute_ctx.expired()); |
| 12705 | compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
| 12706 | ctx->compute_ctx = compute_ctx;
| 12707 | ggml_vk_ctx_begin(ctx->device, compute_ctx);
| 12708 | compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
| 12709 | } |
| 12710 | |
| 12711 | ctx->prealloc_y_last_pipeline_used = nullptr; |
| 12712 | ctx->prealloc_y_last_tensor_used = nullptr; |
| 12713 | |
| 12714 | if (ctx->prealloc_size_add_rms_partials) { |
| 12715 | ggml_vk_preallocate_buffers(ctx, nullptr);
| 12716 | if (ctx->compute_ctx.expired()) { |
| 12717 | compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
| 12718 | ctx->compute_ctx = compute_ctx; |
| 12719 | ggml_vk_ctx_begin(ctx->device, compute_ctx);
| 12720 | } else { |
| 12721 | compute_ctx = ctx->compute_ctx.lock(); |
| 12722 | } |
| 12723 | // initialize partial sums to zero. |
| 12724 | ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
| 12725 | ggml_vk_sync_buffers(ctx, compute_ctx);
| 12726 | } |
| 12727 | |
| 12728 | // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. |
| 12729 | // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB |
| 12730 | // (and scaled down based on model size, so smaller models submit earlier). |
| 12731 | // Also submit at least every 100 nodes, in case there are workloads without as much matmul. |
| 12732 | int nodes_per_submit = 100; |
| 12733 | int submitted_nodes = 0; |
| 12734 | int submit_count = 0; |
| 12735 | uint64_t mul_mat_bytes = 0; |
| 12736 | uint64_t total_mul_mat_bytes = 0; |
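| | // e.g. if the previous graph read 2 GB of matmul weights, submit every ~50 MB (2 GB / 40);
| | // larger models are clamped by the 100 MB cap computed below.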
| 12737 | uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
| 12738 | for (int i = 0; i < cgraph->n_nodes; i++) { |
| 12739 | if (first_node_in_batch) { |
| 12740 | submit_node_idx = i; |
| 12741 | } |
| 12742 | |
| 12743 | if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { |
| 12744 | auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
| 12745 | mul_mat_bytes += bytes; |
| 12746 | total_mul_mat_bytes += bytes; |
| 12747 | } |
| 12748 | |
| 12749 | if (!ctx->device->disable_fusion) { |
| 12750 | uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
| 12751 | if (num_adds) { |
| 12752 | ctx->num_additional_fused_ops = num_adds - 1; |
| 12753 | } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
| 12754 | ctx->num_additional_fused_ops = 1; |
| 12755 | } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
| 12756 | ctx->num_additional_fused_ops = 1; |
| 12757 | } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
| 12758 | ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
| 12759 | ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
| 12760 | ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
| 12761 | ctx->num_additional_fused_ops = 4; |
| 12762 | } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE }) &&
| 12763 | ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
| 12764 | ctx->num_additional_fused_ops = 2; |
| 12765 | } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
| 12766 | ctx->num_additional_fused_ops = 1; |
| 12767 | } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
| 12768 | ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
| 12769 | ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
| 12770 | ctx->num_additional_fused_ops = 2; |
| 12771 | } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
| 12772 | ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
| 12773 | ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
| 12774 | ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; |
| 12775 | // view of argsort writes to memory |
| 12776 | ctx->fused_ops_write_mask |= 1 << 3; |
| 12777 | } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
| 12778 | ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
| 12779 | ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
| 12780 | ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; |
| 12781 | // view of argsort writes to memory |
| 12782 | ctx->fused_ops_write_mask |= 1 << 3; |
| 12783 | } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
| 12784 | ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
| 12785 | ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
| 12786 | ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; |
| 12787 | // view of argsort writes to memory |
| 12788 | ctx->fused_ops_write_mask |= 1 << 1; |
| 12789 | } |
| 12790 | } |
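| | // The last node of the fused group always writes its output, so mark it in the write mask.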
| 12791 | ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; |
| 12792 | |
| 12793 | // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) |
| 12794 | bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; |
| 12795 | bool submit = (submitted_nodes >= nodes_per_submit) || |
| 12796 | (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) || |
| 12797 | (i + ctx->num_additional_fused_ops >= last_node) || |
| 12798 | (almost_ready && !ctx->almost_ready_fence_pending); |
| 12799 | |
| 12800 | bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
| 12801 | |
| 12802 | if (vk_perf_logger_enabled) { |
| 12803 | if (ctx->compute_ctx.expired()) { |
| 12804 | compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
| 12805 | ctx->compute_ctx = compute_ctx; |
| 12806 | ggml_vk_ctx_begin(ctx->device, compute_ctx);
| 12807 | } else { |
| 12808 | compute_ctx = ctx->compute_ctx.lock(); |
| 12809 | } |
| 12810 | // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple |
| 12811 | for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) { |
| 12812 | compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
| 12813 | } |
| 12814 | } |
| 12815 | |
| 12816 | if (enqueued) { |
| 12817 | ++submitted_nodes; |
| 12818 | |
| 12819 | #ifndef GGML_VULKAN_CHECK_RESULTS |
| 12820 | if (first_node_in_batch) { |
| 12821 | first_node_in_batch = false; |
| 12822 | } |
| 12823 | #endif |
| 12824 | } |
| 12825 | |
| 12826 | if (submit && enqueued) { |
| 12827 | first_node_in_batch = true; |
| 12828 | submitted_nodes = 0; |
| 12829 | mul_mat_bytes = 0; |
| 12830 | if (submit_count < 3) { |
| 12831 | mul_mat_bytes_per_submit *= 2; |
| 12832 | } |
| 12833 | submit_count++; |
| 12834 | } |
| 12835 | i += ctx->num_additional_fused_ops; |
| 12836 | ctx->num_additional_fused_ops = 0; |
| 12837 | ctx->fused_ops_write_mask = 0; |
| 12838 | } |
| 12839 | |
| 12840 | ctx->prealloc_size_add_rms_partials = std::max(ctx->prealloc_size_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
| 12841 | ctx->last_total_mul_mat_bytes = total_mul_mat_bytes; |
| 12842 | |
| 12843 | if (vk_perf_logger_enabled) { |
| 12844 | // End the command buffer and submit/wait |
| 12845 | GGML_ASSERT(!ctx->compute_ctx.expired()); |
| 12846 | compute_ctx = ctx->compute_ctx.lock(); |
| 12847 | ggml_vk_ctx_end(compute_ctx);
| 12848 | |
| 12849 | ggml_vk_submit(compute_ctx, ctx->device->fence);
| 12850 | VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
| 12851 | ctx->device->device.resetFences({ ctx->device->fence });
| 12852 | |
| 12853 | // Get the results and pass them to the logger |
| 12854 | std::vector<uint64_t> timestamps(cgraph->n_nodes + 1); |
| 12855 | VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results" ); |
| 12856 | for (int i = 0; i < cgraph->n_nodes; i++) { |
| 12857 | if (!ggml_vk_is_empty(cgraph->nodes[i])) {
| 12858 | ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
| 12859 | } |
| 12860 | } |
| 12861 | |
| 12862 | ctx->device->perf_logger->print_timings(); |
| 12863 | } |
| 12864 | |
| 12865 | ggml_vk_graph_cleanup(ctx); |
| 12866 | |
| 12867 | return GGML_STATUS_SUCCESS; |
| 12868 | |
| 12869 | UNUSED(backend); |
| 12870 | } |
| 12871 | |
| 12872 | // Sort the graph for improved parallelism. |
| 12873 | static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) |
| 12874 | { |
| 12875 | VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)" ); |
| 12876 | ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; |
| 12877 | |
| 12878 | if (ctx->device->disable_graph_optimize) { |
| 12879 | return; |
| 12880 | } |
| 12881 | |
| 12882 | auto const &is_empty = [](ggml_tensor * node) -> bool { |
| 12883 | return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; |
| 12884 | }; |
| 12885 | |
| 12886 | auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { |
| 12887 | for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { |
| 12888 | if (dst->src[s] == src) { |
| 12889 | return true; |
| 12890 | } |
| 12891 | } |
| 12892 | // implicit dependency if they view the same tensor |
| 12893 | const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; |
| 12894 | const ggml_tensor *src2 = src->view_src ? src->view_src : src; |
| 12895 | if (dst2 == src2) { |
| 12896 | return true; |
| 12897 | } |
| 12898 | return false; |
| 12899 | }; |
| 12900 | |
| 12901 | // This function tries to reorder the graph to allow nodes to run in parallel. |
| 12902 | // This helps with small batches, but for large batches it's a slowdown, probably
| 12903 | // due to cache contention. So only reorder if the majority of nodes have few rows. |
| 12904 | int num_small_nodes = 0; |
| 12905 | int num_counted_nodes = 0; |
| 12906 | for (int i = 0; i < graph->n_nodes; ++i) { |
| 12907 | if (!is_empty(graph->nodes[i]) && |
| 12908 | graph->nodes[i]->op != GGML_OP_SET_ROWS) { |
| 12909 | if (ggml_nrows(graph->nodes[i]) <= 8) {
| 12910 | num_small_nodes++; |
| 12911 | } |
| 12912 | num_counted_nodes++; |
| 12913 | } |
| 12914 | } |
| 12915 | if (num_small_nodes < num_counted_nodes / 2) { |
| 12916 | return; |
| 12917 | } |
| 12918 | |
| 12919 | std::vector<ggml_tensor *> new_order; |
| 12920 | std::vector<bool> used(graph->n_nodes, false); |
| 12921 | int first_unused = 0; |
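| | // Greedy reordering: repeatedly take the earliest unprocessed node and pull in later
| | // nodes that don't depend on anything not yet emitted, so independent work ends up adjacent.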
| 12922 | while (first_unused < graph->n_nodes) { |
| 12923 | std::vector<int> current_set; |
| 12924 | |
| 12925 | // Check for fusion patterns and avoid reordering them |
| 12926 | auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool { |
| 12927 | if (start + (int)pattern.size() <= graph->n_nodes) { |
| 12928 | bool is_pattern = true; |
| 12929 | for (size_t j = 0; j < pattern.size(); ++j) { |
| 12930 | if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) { |
| 12931 | is_pattern = false; |
| 12932 | } |
| 12933 | } |
| 12934 | return is_pattern; |
| 12935 | } |
| 12936 | return false; |
| 12937 | }; |
| 12938 | |
| 12939 | auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool { |
| 12940 | if (match_pattern(pattern, first_unused)) { |
| 12941 | for (size_t j = 0; j < pattern.size(); ++j) { |
| 12942 | new_order.push_back(graph->nodes[first_unused + j]); |
| 12943 | used[first_unused + j] = true; |
| 12944 | } |
| 12945 | while (first_unused < graph->n_nodes && used[first_unused]) { |
| 12946 | first_unused++; |
| 12947 | } |
| 12948 | return true; |
| 12949 | } |
| 12950 | return false; |
| 12951 | }; |
| 12952 | |
| 12953 | if (keep_pattern(topk_moe_early_softmax_norm)) { |
| 12954 | continue; |
| 12955 | } |
| 12956 | if (keep_pattern(topk_moe_early_softmax)) { |
| 12957 | continue; |
| 12958 | } |
| 12959 | if (keep_pattern(topk_moe_late_softmax)) { |
| 12960 | continue; |
| 12961 | } |
| 12962 | |
| 12963 | // First, grab the next unused node. |
| 12964 | current_set.push_back(first_unused); |
| 12965 | |
| 12966 | // Loop through the next N nodes. Grab any that don't depend on other nodes that |
| 12967 | // haven't already been run. Nodes that have already been run have used[i] set |
| 12968 | // to true. Allow nodes that depend on the previous node if it's a fusion pattern |
| 12969 | // that we support (e.g. RMS_NORM + MUL). |
| 12970 | // This first pass only grabs "real" (non-view) nodes. The second pass grabs view nodes. |
| 12971 | // The goal is to not interleave real and view nodes in a way that breaks fusion. |
| 12972 | const int NUM_TO_CHECK = 20; |
| 12973 | for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { |
| 12974 | if (used[j]) { |
| 12975 | continue; |
| 12976 | } |
| 12977 | if (is_empty(graph->nodes[j])) { |
| 12978 | continue; |
| 12979 | } |
| 12980 | // Don't pull forward nodes from fusion patterns |
| 12981 | if (match_pattern(topk_moe_early_softmax_norm, j) || |
| 12982 | match_pattern(topk_moe_early_softmax, j) || |
| 12983 | match_pattern(topk_moe_late_softmax, j)) { |
| 12984 | continue; |
| 12985 | } |
| 12986 | bool ok = true; |
| 12987 | for (int c = first_unused; c < j; ++c) { |
| 12988 | if (!used[c] && |
| 12989 | is_src_of(graph->nodes[j], graph->nodes[c]) && |
| 12990 | !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) && |
| 12991 | !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && |
| 12992 | !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) { |
| 12993 | ok = false; |
| 12994 | break; |
| 12995 | } |
| 12996 | } |
| 12997 | if (ok) { |
| 12998 | current_set.push_back(j); |
| 12999 | |
| 13000 | int rope_idx = j; |
| 13001 | |
| 13002 | // When we've found RMS_NORM + MUL, try to find a ROPE that uses it |
| 13003 | if (j > 0 && |
| 13004 | graph->nodes[j]->op == GGML_OP_MUL && |
| 13005 | graph->nodes[j-1]->op == GGML_OP_RMS_NORM) { |
| 13006 | for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { |
| 13007 | if (graph->nodes[k]->op == GGML_OP_ROPE && |
| 13008 | graph->nodes[k]->src[0] == graph->nodes[j] && |
| 13009 | // Check that other srcs are already valid |
| 13010 | graph->nodes[k]->src[1]->op == GGML_OP_NONE && |
| 13011 | (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) { |
| 13012 | rope_idx = k; |
| 13013 | current_set.push_back(rope_idx); |
| 13014 | used[rope_idx] = true; |
| 13015 | break; |
| 13016 | } |
| 13017 | } |
| 13018 | } |
| 13019 | // Look for ROPE + VIEW + SET_ROWS and make them consecutive |
| 13020 | if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) { |
| 13021 | int view_idx = -1; |
| 13022 | int set_rows_idx = -1; |
| 13023 | for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) { |
| 13024 | if (view_idx == -1 && |
| 13025 | graph->nodes[k]->op == GGML_OP_VIEW && |
| 13026 | graph->nodes[k]->src[0] == graph->nodes[rope_idx]) { |
| 13027 | view_idx = k; |
| 13028 | continue; |
| 13029 | } |
| 13030 | if (view_idx != -1 && |
| 13031 | set_rows_idx == -1 && |
| 13032 | graph->nodes[k]->op == GGML_OP_SET_ROWS && |
| 13033 | graph->nodes[k]->src[0] == graph->nodes[view_idx]) { |
| 13034 | set_rows_idx = k; |
| 13035 | break; |
| 13036 | } |
| 13037 | } |
| 13038 | if (set_rows_idx != -1) { |
| 13039 | current_set.push_back(view_idx); |
| 13040 | current_set.push_back(set_rows_idx); |
| 13041 | used[view_idx] = true; |
| 13042 | used[set_rows_idx] = true; |
| 13043 | } |
| 13044 | } |
| 13045 | } |
| 13046 | } |
| 13047 | // Second pass grabs view nodes. |
| 13048 | // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). |
| 13049 | if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { |
| 13050 | for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { |
| 13051 | if (used[j]) { |
| 13052 | continue; |
| 13053 | } |
| 13054 | if (!is_empty(graph->nodes[j])) { |
| 13055 | continue; |
| 13056 | } |
| 13057 | bool ok = true; |
| 13058 | for (int c = first_unused; c < j; ++c) { |
| 13059 | bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); |
| 13060 | // skip views whose srcs haven't been processed. |
| 13061 | if (!used[c] && |
| 13062 | is_src_of(graph->nodes[j], graph->nodes[c]) && |
| 13063 | !c_in_current_set) { |
| 13064 | ok = false; |
| 13065 | break; |
| 13066 | } |
| 13067 | } |
| 13068 | if (ok) { |
| 13069 | current_set.push_back(j); |
| 13070 | } |
| 13071 | } |
| 13072 | } |
| 13073 | |
| 13074 | // Push the current set into new_order |
| 13075 | for (auto c : current_set) { |
| 13076 | new_order.push_back(graph->nodes[c]); |
| 13077 | used[c] = true; |
| 13078 | } |
| 13079 | while (first_unused < graph->n_nodes && used[first_unused]) { |
| 13080 | first_unused++; |
| 13081 | } |
| 13082 | } |
| 13083 | // Replace the graph with the new order. |
| 13084 | for (int i = 0; i < graph->n_nodes; ++i) { |
| 13085 | graph->nodes[i] = new_order[i]; |
| 13086 | } |
| 13087 | } |
| 13088 | |
| 13089 | // TODO: enable async and synchronize |
| 13090 | static ggml_backend_i ggml_backend_vk_interface = { |
| 13091 | /* .get_name = */ ggml_backend_vk_name, |
| 13092 | /* .free = */ ggml_backend_vk_free, |
| 13093 | /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, |
| 13094 | /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, |
| 13095 | /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, |
| 13096 | /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, |
| 13097 | /* .graph_plan_create = */ NULL, |
| 13098 | /* .graph_plan_free = */ NULL, |
| 13099 | /* .graph_plan_update = */ NULL, |
| 13100 | /* .graph_plan_compute = */ NULL, |
| 13101 | /* .graph_compute = */ ggml_backend_vk_graph_compute, |
| 13102 | /* .event_record = */ NULL, |
| 13103 | /* .event_wait = */ NULL, |
| 13104 | /* .graph_optimize = */ ggml_vk_graph_optimize, |
| 13105 | }; |
| 13106 | |
| 13107 | static ggml_guid_t ggml_backend_vk_guid() { |
| 13108 | static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; |
| 13109 | return &guid; |
| 13110 | } |
| 13111 | |
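| | // Illustrative usage sketch (not part of this file): a minimal host-side setup |
| | // through the public API, assuming the standard ggml-backend entry points. |
| | //   ggml_backend_t backend = ggml_backend_vk_init(0); |
| | //   if (backend != nullptr) { |
| | //       ggml_backend_buffer_type_t buft = ggml_backend_vk_buffer_type(0); |
| | //       // ... allocate tensors from buft, build a graph, then: |
| | //       // ggml_backend_graph_compute(backend, graph); |
| | //       ggml_backend_free(backend); |
| | //   } |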
| 13112 | ggml_backend_t ggml_backend_vk_init(size_t dev_num) { |
| 13113 | VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")" ); |
| 13114 | |
| 13115 | ggml_backend_vk_context * ctx = new ggml_backend_vk_context; |
| 13116 | ggml_vk_init(ctx, dev_num); |
| 13117 | |
| 13118 | ggml_backend_t vk_backend = new ggml_backend { |
| 13119 | /* .guid = */ ggml_backend_vk_guid(), |
| 13120 | /* .iface = */ ggml_backend_vk_interface, |
| 13121 | /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), |
| 13122 | /* .context = */ ctx, |
| 13123 | }; |
| 13124 | |
| 13125 | return vk_backend; |
| 13126 | } |
| 13127 | |
| 13128 | bool ggml_backend_is_vk(ggml_backend_t backend) { |
| 13129 | return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); |
| 13130 | } |
| 13131 | |
| 13132 | int ggml_backend_vk_get_device_count() { |
| 13133 | return ggml_vk_get_device_count(); |
| 13134 | } |
| 13135 | |
| 13136 | void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { |
| 13137 | GGML_ASSERT(device < (int) vk_instance.device_indices.size()); |
| 13138 | int dev_idx = vk_instance.device_indices[device]; |
| 13139 | ggml_vk_get_device_description(dev_idx, description, description_size); |
| 13140 | } |
| 13141 | |
| 13142 | void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { |
| 13143 | GGML_ASSERT(device < (int) vk_instance.device_indices.size()); |
| 13144 | GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); |
| 13145 | |
| 13146 | vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; |
| 13147 | vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; |
| 13148 | vk::PhysicalDeviceMemoryProperties2 memprops = {}; |
| 13149 | bool membudget_supported = vk_instance.device_supports_membudget[device]; |
| 13150 | |
| 13151 | if (membudget_supported) { |
| 13152 | memprops.pNext = &budgetprops; |
| 13153 | } |
| 13154 | vkdev.getMemoryProperties2(&memprops); |
| 13155 | |
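| | // Report the first device-local heap. Prefer VK_EXT_memory_budget for a live |
| | // "free" estimate; otherwise fall back to reporting the full heap size as free. |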
| 13156 | for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { |
| 13157 | const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; |
| 13158 | |
| 13159 | if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { |
| 13160 | *total = heap.size; |
| 13161 | |
| 13162 | if (membudget_supported && i < budgetprops.heapUsage.size()) { |
| 13163 | *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i]; |
| 13164 | } else { |
| 13165 | *free = heap.size; |
| 13166 | } |
| 13167 | break; |
| 13168 | } |
| 13169 | } |
| 13170 | } |
| 13171 | |
| 13172 | static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { |
| 13173 | GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); |
| 13174 | |
| 13175 | vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; |
| 13176 | |
| 13177 | vk::PhysicalDeviceProperties2 props = {}; |
| 13178 | device.getProperties2(&props); |
| 13179 | |
| 13180 | return props.properties.deviceType; |
| 13181 | } |
| 13182 | |
| 13183 | static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { |
| 13184 | GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); |
| 13185 | |
| 13186 | vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; |
| 13187 | |
| 13188 | const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties(); |
| 13189 | |
| 13190 | bool ext_support = false; |
| 13191 | |
| 13192 | for (const auto& properties : ext_props) { |
| 13193 | if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) { |
| 13194 | ext_support = true; |
| 13195 | break; |
| 13196 | } |
| 13197 | } |
| 13198 | |
| 13199 | if (!ext_support) { |
| 13200 | return "" ; |
| 13201 | } |
| 13202 | |
| 13203 | vk::PhysicalDeviceProperties2 props = {}; |
| 13204 | vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {}; |
| 13205 | |
| 13206 | props.pNext = &pci_bus_info; |
| 13207 | |
| 13208 | device.getProperties2(&props); |
| 13209 | |
| 13210 | const uint32_t pci_domain = pci_bus_info.pciDomain; |
| 13211 | const uint32_t pci_bus = pci_bus_info.pciBus; |
| 13212 | const uint32_t pci_device = pci_bus_info.pciDevice; |
| 13213 | const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning |
| 13214 | |
| 13215 | char pci_bus_id[16] = {}; |
| 13216 | snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); |
| 13217 | |
| 13218 | return std::string(pci_bus_id); |
| 13219 | } |
| 13220 | |
| 13221 | ////////////////////////// |
| 13222 | |
| 13223 | struct ggml_backend_vk_device_context { |
| 13224 | size_t device; |
| 13225 | std::string name; |
| 13226 | std::string description; |
| 13227 | bool is_integrated_gpu; |
| 13228 | std::string pci_bus_id; |
| 13229 | }; |
| 13230 | |
| 13231 | static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { |
| 13232 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13233 | return ctx->name.c_str(); |
| 13234 | } |
| 13235 | |
| 13236 | static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { |
| 13237 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13238 | return ctx->description.c_str(); |
| 13239 | } |
| 13240 | |
| 13241 | static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { |
| 13242 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; |
| 13243 | ggml_backend_vk_get_device_memory(ctx->device, free, total); |
| 13244 | } |
| 13245 | |
| 13246 | static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { |
| 13247 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13248 | return ggml_backend_vk_buffer_type(ctx->device); |
| 13249 | } |
| 13250 | |
| 13251 | static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { |
| 13252 | UNUSED(dev); |
| 13253 | return ggml_backend_vk_host_buffer_type(); |
| 13254 | } |
| 13255 | |
| 13256 | static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { |
| 13257 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13258 | |
| 13259 | return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU; |
| 13260 | } |
| 13261 | |
| 13262 | static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { |
| 13263 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13264 | |
| 13265 | props->name = ggml_backend_vk_device_get_name(dev); |
| 13266 | props->description = ggml_backend_vk_device_get_description(dev); |
| 13267 | props->type = ggml_backend_vk_device_get_type(dev); |
| 13268 | props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); |
| 13269 | ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); |
| 13270 | props->caps = { |
| 13271 | /* .async = */ false, |
| 13272 | /* .host_buffer = */ true, |
| 13273 | /* .buffer_from_host_ptr = */ false, |
| 13274 | /* .events = */ false, |
| 13275 | }; |
| 13276 | } |
| 13277 | |
| 13278 | static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { |
| 13279 | UNUSED(params); |
| 13280 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13281 | return ggml_backend_vk_init(ctx->device); |
| 13282 | } |
| 13283 | |
| 13284 | static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { |
| 13285 | switch (op->op) { |
| 13286 | case GGML_OP_UNARY: |
| 13287 | switch (ggml_get_unary_op(op)) { |
| 13288 | case GGML_UNARY_OP_EXP: |
| 13289 | case GGML_UNARY_OP_GELU: |
| 13290 | case GGML_UNARY_OP_GELU_ERF: |
| 13291 | case GGML_UNARY_OP_GELU_QUICK: |
| 13292 | case GGML_UNARY_OP_SILU: |
| 13293 | case GGML_UNARY_OP_RELU: |
| 13294 | case GGML_UNARY_OP_TANH: |
| 13295 | case GGML_UNARY_OP_SIGMOID: |
| 13296 | case GGML_UNARY_OP_HARDSIGMOID: |
| 13297 | case GGML_UNARY_OP_HARDSWISH: |
| 13298 | return ggml_is_contiguous(op->src[0]) && |
| 13299 | (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && |
| 13300 | (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && |
| 13301 | (op->src[0]->type == op->type); |
| 13302 | default: |
| 13303 | return false; |
| 13304 | } |
| 13305 | case GGML_OP_GLU: |
| 13306 | switch (ggml_get_glu_op(op)) { |
| 13307 | case GGML_GLU_OP_GEGLU: |
| 13308 | case GGML_GLU_OP_REGLU: |
| 13309 | case GGML_GLU_OP_SWIGLU: |
| 13310 | case GGML_GLU_OP_SWIGLU_OAI: |
| 13311 | case GGML_GLU_OP_GEGLU_ERF: |
| 13312 | case GGML_GLU_OP_GEGLU_QUICK: |
| 13313 | return ggml_is_contiguous(op->src[0]) && |
| 13314 | (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && |
| 13315 | (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && |
| 13316 | (op->src[0]->type == op->type); |
| 13317 | default: |
| 13318 | return false; |
| 13319 | } |
| 13320 | case GGML_OP_MUL_MAT: |
| 13321 | case GGML_OP_MUL_MAT_ID: |
| 13322 | { |
| 13323 | ggml_type src0_type = op->src[0]->type; |
| 13324 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13325 | const vk_device& device = ggml_vk_get_device(ctx->device); |
| 13326 | if (op->op == GGML_OP_MUL_MAT_ID) { |
| 13327 | if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { |
| 13328 | // If there's not enough shared memory for row_ids and the result tile, fallback to CPU |
| 13329 | return false; |
| 13330 | } |
| 13331 | } |
| 13332 | switch (src0_type) { |
| 13333 | case GGML_TYPE_F32: |
| 13334 | case GGML_TYPE_F16: |
| 13335 | case GGML_TYPE_BF16: |
| 13336 | case GGML_TYPE_Q4_0: |
| 13337 | case GGML_TYPE_Q4_1: |
| 13338 | case GGML_TYPE_Q5_0: |
| 13339 | case GGML_TYPE_Q5_1: |
| 13340 | case GGML_TYPE_Q8_0: |
| 13341 | case GGML_TYPE_Q2_K: |
| 13342 | case GGML_TYPE_Q3_K: |
| 13343 | case GGML_TYPE_Q4_K: |
| 13344 | case GGML_TYPE_Q5_K: |
| 13345 | case GGML_TYPE_Q6_K: |
| 13346 | case GGML_TYPE_IQ1_S: |
| 13347 | case GGML_TYPE_IQ1_M: |
| 13348 | case GGML_TYPE_IQ2_XXS: |
| 13349 | case GGML_TYPE_IQ2_XS: |
| 13350 | case GGML_TYPE_IQ2_S: |
| 13351 | case GGML_TYPE_IQ3_XXS: |
| 13352 | case GGML_TYPE_IQ3_S: |
| 13353 | case GGML_TYPE_IQ4_XS: |
| 13354 | case GGML_TYPE_IQ4_NL: |
| 13355 | case GGML_TYPE_MXFP4: |
| 13356 | break; |
| 13357 | default: |
| 13358 | return false; |
| 13359 | } |
| 13360 | struct ggml_tensor * a; |
| 13361 | struct ggml_tensor * b; |
| 13362 | if (op->op == GGML_OP_MUL_MAT) { |
| 13363 | a = op->src[0]; |
| 13364 | b = op->src[1]; |
| 13365 | } else { |
| 13366 | a = op->src[2]; |
| 13367 | b = op->src[1]; |
| 13368 | } |
| 13369 | if (a->ne[3] != b->ne[3]) { |
| 13370 | return false; |
| 13371 | } |
| 13372 | if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) || |
| 13373 | !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { |
| 13374 | return false; |
| 13375 | } |
| 13376 | if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) { |
| 13377 | // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader. |
| 13378 | // So don't support this combination for now. |
| 13379 | return false; |
| 13380 | } |
| 13381 | |
| 13382 | return true; |
| 13383 | } |
| 13384 | case GGML_OP_FLASH_ATTN_EXT: |
| 13385 | { |
| 13386 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13387 | auto device = ggml_vk_get_device(ctx->device); |
| 13388 | bool coopmat2 = device->coopmat2; |
| 13389 | uint32_t HSK = op->src[1]->ne[0]; |
| 13390 | uint32_t HSV = op->src[2]->ne[0]; |
| 13391 | if ((HSK % 8) != 0 || (HSV % 8) != 0) { |
| 13392 | return false; |
| 13393 | } |
| 13394 | if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) { |
| 13395 | return false; |
| 13396 | } |
| 13397 | if (op->src[0]->type != GGML_TYPE_F32) { |
| 13398 | return false; |
| 13399 | } |
| 13400 | if (op->type != GGML_TYPE_F32) { |
| 13401 | return false; |
| 13402 | } |
| 13403 | if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { |
| 13404 | return false; |
| 13405 | } |
| 13406 | // It's straightforward to support different K/V dequant, but would |
| 13407 | // significantly increase the number of pipelines |
| 13408 | if (op->src[1]->type != op->src[2]->type) { |
| 13409 | return false; |
| 13410 | } |
| 13411 | switch (op->src[1]->type) { |
| 13412 | case GGML_TYPE_F16: |
| 13413 | case GGML_TYPE_F32: |
| 13414 | case GGML_TYPE_Q4_0: |
| 13415 | case GGML_TYPE_Q8_0: |
| 13416 | // supported in scalar and coopmat2 paths |
| 13417 | break; |
| 13418 | case GGML_TYPE_Q4_1: |
| 13419 | case GGML_TYPE_Q5_0: |
| 13420 | case GGML_TYPE_Q5_1: |
| 13421 | // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently |
| 13422 | //case GGML_TYPE_Q2_K: |
| 13423 | //case GGML_TYPE_Q3_K: |
| 13424 | //case GGML_TYPE_Q4_K: |
| 13425 | //case GGML_TYPE_Q5_K: |
| 13426 | //case GGML_TYPE_Q6_K: |
| 13427 | //case GGML_TYPE_IQ1_S: |
| 13428 | //case GGML_TYPE_IQ1_M: |
| 13429 | //case GGML_TYPE_IQ2_XXS: |
| 13430 | //case GGML_TYPE_IQ2_XS: |
| 13431 | //case GGML_TYPE_IQ2_S: |
| 13432 | //case GGML_TYPE_IQ3_XXS: |
| 13433 | //case GGML_TYPE_IQ3_S: |
| 13434 | //case GGML_TYPE_IQ4_XS: |
| 13435 | case GGML_TYPE_IQ4_NL: |
| 13436 | // currently supported only in coopmat2 path |
| 13437 | if (!coopmat2) { |
| 13438 | return false; |
| 13439 | } |
| 13440 | break; |
| 13441 | default: |
| 13442 | return false; |
| 13443 | } |
| 13444 | if (!coopmat2 && !device->subgroup_shuffle) { |
| 13445 | // scalar FA uses subgroupShuffle |
| 13446 | return false; |
| 13447 | } |
| 13448 | return true; |
| 13449 | } |
| 13450 | case GGML_OP_GET_ROWS: |
| 13451 | { |
| 13452 | switch (op->src[0]->type) { |
| 13453 | case GGML_TYPE_F32: |
| 13454 | case GGML_TYPE_F16: |
| 13455 | case GGML_TYPE_BF16: |
| 13456 | case GGML_TYPE_Q4_0: |
| 13457 | case GGML_TYPE_Q4_1: |
| 13458 | case GGML_TYPE_Q5_0: |
| 13459 | case GGML_TYPE_Q5_1: |
| 13460 | case GGML_TYPE_Q8_0: |
| 13461 | case GGML_TYPE_Q2_K: |
| 13462 | case GGML_TYPE_Q3_K: |
| 13463 | case GGML_TYPE_Q4_K: |
| 13464 | case GGML_TYPE_Q5_K: |
| 13465 | case GGML_TYPE_Q6_K: |
| 13466 | case GGML_TYPE_IQ1_S: |
| 13467 | case GGML_TYPE_IQ1_M: |
| 13468 | case GGML_TYPE_IQ2_XXS: |
| 13469 | case GGML_TYPE_IQ2_XS: |
| 13470 | case GGML_TYPE_IQ2_S: |
| 13471 | case GGML_TYPE_IQ3_XXS: |
| 13472 | case GGML_TYPE_IQ3_S: |
| 13473 | case GGML_TYPE_IQ4_XS: |
| 13474 | case GGML_TYPE_IQ4_NL: |
| 13475 | case GGML_TYPE_MXFP4: |
| 13476 | return true; |
| 13477 | default: |
| 13478 | return false; |
| 13479 | } |
| 13480 | } |
| 13481 | case GGML_OP_SET_ROWS: |
| 13482 | { |
| 13483 | switch (op->type) { |
| 13484 | case GGML_TYPE_F32: |
| 13485 | case GGML_TYPE_F16: |
| 13486 | case GGML_TYPE_BF16: |
| 13487 | case GGML_TYPE_Q4_0: |
| 13488 | case GGML_TYPE_Q4_1: |
| 13489 | case GGML_TYPE_Q5_0: |
| 13490 | case GGML_TYPE_Q5_1: |
| 13491 | case GGML_TYPE_Q8_0: |
| 13492 | case GGML_TYPE_IQ4_NL: |
| 13493 | return true; |
| 13494 | default: |
| 13495 | return false; |
| 13496 | } |
| 13497 | } |
| 13498 | case GGML_OP_CONT: |
| 13499 | case GGML_OP_CPY: |
| 13500 | case GGML_OP_DUP: |
| 13501 | { |
| 13502 | ggml_type src0_type = op->src[0]->type; |
| 13503 | ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; |
| 13504 | |
| 13505 | if (src0_type == GGML_TYPE_F32) { |
| 13506 | switch (src1_type) { |
| 13507 | case GGML_TYPE_F32: |
| 13508 | case GGML_TYPE_F16: |
| 13509 | case GGML_TYPE_BF16: |
| 13510 | case GGML_TYPE_Q4_0: |
| 13511 | case GGML_TYPE_Q4_1: |
| 13512 | case GGML_TYPE_Q5_0: |
| 13513 | case GGML_TYPE_Q5_1: |
| 13514 | case GGML_TYPE_Q8_0: |
| 13515 | case GGML_TYPE_IQ4_NL: |
| 13516 | return true; |
| 13517 | default: |
| 13518 | break; |
| 13519 | } |
| 13520 | } |
| 13521 | if (src1_type == GGML_TYPE_F32) { |
| 13522 | switch (src0_type) { |
| 13523 | case GGML_TYPE_F16: |
| 13524 | case GGML_TYPE_Q4_0: |
| 13525 | case GGML_TYPE_Q4_1: |
| 13526 | case GGML_TYPE_Q5_0: |
| 13527 | case GGML_TYPE_Q5_1: |
| 13528 | case GGML_TYPE_Q8_0: |
| 13529 | case GGML_TYPE_IQ4_NL: |
| 13530 | return true; |
| 13531 | default: |
| 13532 | break; |
| 13533 | } |
| 13534 | } |
| 13535 | |
| 13536 | if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { |
| 13537 | return true; |
| 13538 | } |
| 13539 | |
| 13540 | if ( |
| 13541 | (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) || |
| 13542 | (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) |
| 13543 | ) { |
| 13544 | return true; |
| 13545 | } |
| 13546 | |
| 13547 | // We can handle copying from a type to the same type if it's |
| 13548 | // contiguous (memcpy). We use f16 or f32 shaders to do the copy, |
| 13549 | // so the type/block size must be a multiple of 2. |
| 13550 | if (src0_type == src1_type && |
| 13551 | ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && |
| 13552 | (ggml_type_size(src0_type) % 2) == 0) { |
| 13553 | return true; |
| 13554 | } |
| 13555 | return false; |
| 13556 | } |
| 13557 | case GGML_OP_REPEAT: |
| 13558 | return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); |
| 13559 | case GGML_OP_REPEAT_BACK: |
| 13560 | return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; |
| 13561 | case GGML_OP_ROPE: |
| 13562 | case GGML_OP_ROPE_BACK: |
| 13563 | case GGML_OP_NONE: |
| 13564 | case GGML_OP_RESHAPE: |
| 13565 | case GGML_OP_VIEW: |
| 13566 | case GGML_OP_PERMUTE: |
| 13567 | case GGML_OP_TRANSPOSE: |
| 13568 | case GGML_OP_RMS_NORM: |
| 13569 | return true; |
| 13570 | case GGML_OP_NORM: |
| 13571 | case GGML_OP_GROUP_NORM: |
| 13572 | case GGML_OP_L2_NORM: |
| 13573 | return ggml_is_contiguous(op->src[0]); |
| 13574 | case GGML_OP_ADD: |
| 13575 | case GGML_OP_SUB: |
| 13576 | case GGML_OP_MUL: |
| 13577 | case GGML_OP_DIV: |
| 13578 | return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && |
| 13579 | (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && |
| 13580 | (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); |
| 13581 | case GGML_OP_ADD_ID: |
| 13582 | return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 && |
| 13583 | op->type == GGML_TYPE_F32; |
| 13584 | case GGML_OP_SILU_BACK: |
| 13585 | case GGML_OP_RMS_NORM_BACK: |
| 13586 | case GGML_OP_SQR: |
| 13587 | case GGML_OP_SQRT: |
| 13588 | case GGML_OP_SIN: |
| 13589 | case GGML_OP_COS: |
| 13590 | case GGML_OP_CLAMP: |
| 13591 | case GGML_OP_LEAKY_RELU: |
| 13592 | case GGML_OP_OPT_STEP_ADAMW: |
| 13593 | case GGML_OP_OPT_STEP_SGD: |
| 13594 | return op->src[0]->type == GGML_TYPE_F32; |
| 13595 | case GGML_OP_ARGSORT: |
| 13596 | return op->ne[0] <= max_argsort_cols; |
| 13597 | case GGML_OP_UPSCALE: |
| 13598 | case GGML_OP_ACC: |
| 13599 | case GGML_OP_CONCAT: |
| 13600 | case GGML_OP_SCALE: |
| 13601 | case GGML_OP_PAD: |
| 13602 | case GGML_OP_ROLL: |
| 13603 | case GGML_OP_DIAG_MASK_INF: |
| 13604 | case GGML_OP_SOFT_MAX: |
| 13605 | case GGML_OP_SOFT_MAX_BACK: |
| 13606 | return true; |
| 13607 | case GGML_OP_SUM: |
| 13608 | case GGML_OP_SUM_ROWS: |
| 13609 | case GGML_OP_MEAN: |
| 13610 | return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); |
| 13611 | case GGML_OP_ARGMAX: |
| 13612 | case GGML_OP_COUNT_EQUAL: |
| 13613 | case GGML_OP_IM2COL: |
| 13614 | case GGML_OP_IM2COL_3D: |
| 13615 | case GGML_OP_TIMESTEP_EMBEDDING: |
| 13616 | case GGML_OP_CONV_2D_DW: |
| 13617 | case GGML_OP_POOL_2D: |
| 13618 | case GGML_OP_RWKV_WKV6: |
| 13619 | case GGML_OP_RWKV_WKV7: |
| 13620 | return true; |
| 13621 | case GGML_OP_SSM_SCAN: |
| 13622 | { |
| 13623 | for (int i = 0; i < 6; i++) { |
| 13624 | if (op->src[i] && ggml_is_quantized(op->src[i]->type)) { |
| 13625 | return false; |
| 13626 | } |
| 13627 | } |
| 13628 | if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) { |
| 13629 | return false; |
| 13630 | } |
| 13631 | if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) { |
| 13632 | return false; |
| 13633 | } |
| 13634 | |
| 13635 | const uint32_t d_state = op->src[0]->ne[0]; |
| 13636 | const uint32_t head_dim = op->src[0]->ne[1]; |
| 13637 | |
| 13638 | bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float)); |
| 13639 | if (!is_mamba2) { |
| 13640 | return false; |
| 13641 | } |
| 13642 | |
| 13643 | if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) { |
| 13644 | return false; |
| 13645 | } |
| 13646 | |
| 13647 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13648 | const vk_device& device = ggml_vk_get_device(ctx->device); |
| 13649 | |
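| | // The ssm_scan shader keeps SPLIT_H * d_state floats of state in shared memory; |
| | // reject devices whose shared memory is too small for that tile. |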
| 13650 | const uint32_t SPLIT_H = 16; |
| 13651 | |
| 13652 | size_t stateC_size = SPLIT_H * d_state * sizeof(float); |
| 13653 | |
| 13654 | if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) { |
| 13655 | return false; |
| 13656 | } |
| 13657 | |
| 13658 | return true; |
| 13659 | } |
| 13660 | case GGML_OP_SSM_CONV: |
| 13661 | return true; |
| 13662 | case GGML_OP_CONV_TRANSPOSE_1D: |
| 13663 | return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; |
| 13664 | case GGML_OP_CONV_2D: |
| 13665 | case GGML_OP_CONV_TRANSPOSE_2D: |
| 13666 | { |
| 13667 | // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK |
| 13668 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13669 | const vk_device& device = ggml_vk_get_device(ctx->device); |
| 13670 | if (op->op == GGML_OP_CONV_TRANSPOSE_2D && |
| 13671 | device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) { |
| 13672 | return false; |
| 13673 | } |
| 13674 | // Channel-contiguous format is not supported yet. |
| 13675 | return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && |
| 13676 | op->src[1]->type == GGML_TYPE_F32 && |
| 13677 | op->type == GGML_TYPE_F32 && |
| 13678 | ggml_is_contiguous(op->src[0]) && |
| 13679 | ggml_is_contiguous(op->src[1]) && |
| 13680 | ggml_is_contiguous(op)); |
| 13681 | } |
| 13682 | default: |
| 13683 | return false; |
| 13684 | } |
| 13685 | |
| 13686 | UNUSED(dev); |
| 13687 | } |
| 13688 | |
| 13689 | static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { |
| 13690 | if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { |
| 13691 | return false; |
| 13692 | } |
| 13693 | |
| 13694 | ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; |
| 13695 | ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; |
| 13696 | |
| 13697 | return buft_ctx->device->idx == ctx->device; |
| 13698 | } |
| 13699 | |
| 13700 | static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { |
| 13701 | const int min_batch_size = 32; |
| 13702 | |
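| | // Only offload ops whose batch dimension is large enough to amortize the cost of |
| | // transferring data to the device. |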
| 13703 | return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || |
| 13704 | (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); |
| 13705 | |
| 13706 | UNUSED(dev); |
| 13707 | } |
| 13708 | |
| 13709 | static const struct ggml_backend_device_i ggml_backend_vk_device_i = { |
| 13710 | /* .get_name = */ ggml_backend_vk_device_get_name, |
| 13711 | /* .get_description = */ ggml_backend_vk_device_get_description, |
| 13712 | /* .get_memory = */ ggml_backend_vk_device_get_memory, |
| 13713 | /* .get_type = */ ggml_backend_vk_device_get_type, |
| 13714 | /* .get_props = */ ggml_backend_vk_device_get_props, |
| 13715 | /* .init_backend = */ ggml_backend_vk_device_init, |
| 13716 | /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, |
| 13717 | /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, |
| 13718 | /* .buffer_from_host_ptr = */ NULL, |
| 13719 | /* .supports_op = */ ggml_backend_vk_device_supports_op, |
| 13720 | /* .supports_buft = */ ggml_backend_vk_device_supports_buft, |
| 13721 | /* .offload_op = */ ggml_backend_vk_device_offload_op, |
| 13722 | /* .event_new = */ NULL, |
| 13723 | /* .event_free = */ NULL, |
| 13724 | /* .event_synchronize = */ NULL, |
| 13725 | }; |
| 13726 | |
| 13727 | static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { |
| 13728 | UNUSED(reg); |
| 13729 | return GGML_VK_NAME; |
| 13730 | } |
| 13731 | |
| 13732 | static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { |
| 13733 | UNUSED(reg); |
| 13734 | return ggml_backend_vk_get_device_count(); |
| 13735 | } |
| 13736 | |
| 13737 | static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { |
| 13738 | static std::vector<ggml_backend_dev_t> devices; |
| 13739 | |
| 13740 | static bool initialized = false; |
| 13741 | |
| 13742 | { |
| 13743 | static std::mutex mutex; |
| 13744 | std::lock_guard<std::mutex> lock(mutex); |
| 13745 | if (!initialized) { |
| 13746 | for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { |
| 13747 | ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; |
| 13748 | char desc[256]; |
| 13749 | ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); |
| 13750 | ctx->device = i; |
| 13751 | ctx->name = GGML_VK_NAME + std::to_string(i); |
| 13752 | ctx->description = desc; |
| 13753 | ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; |
| 13754 | ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); |
| 13755 | devices.push_back(new ggml_backend_device { |
| 13756 | /* .iface = */ ggml_backend_vk_device_i, |
| 13757 | /* .reg = */ reg, |
| 13758 | /* .context = */ ctx, |
| 13759 | }); |
| 13760 | } |
| 13761 | initialized = true; |
| 13762 | } |
| 13763 | } |
| 13764 | |
| 13765 | GGML_ASSERT(device < devices.size()); |
| 13766 | return devices[device]; |
| 13767 | } |
| 13768 | |
| 13769 | static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { |
| 13770 | /* .get_name = */ ggml_backend_vk_reg_get_name, |
| 13771 | /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, |
| 13772 | /* .get_device = */ ggml_backend_vk_reg_get_device, |
| 13773 | /* .get_proc_address = */ NULL, |
| 13774 | }; |
| 13775 | |
| 13776 | ggml_backend_reg_t ggml_backend_vk_reg() { |
| 13777 | static ggml_backend_reg reg = { |
| 13778 | /* .api_version = */ GGML_BACKEND_API_VERSION, |
| 13779 | /* .iface = */ ggml_backend_vk_reg_i, |
| 13780 | /* .context = */ nullptr, |
| 13781 | }; |
| 13782 | try { |
| 13783 | ggml_vk_instance_init(); |
| 13784 | return &reg; |
| 13785 | } catch (const vk::SystemError& e) { |
| 13786 | VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); |
| 13787 | return nullptr; |
| 13788 | } catch (const std::exception &e) { |
| 13789 | VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what()); |
| 13790 | return nullptr; |
| 13791 | } catch (...) { |
| 13792 | VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init" ); |
| 13793 | return nullptr; |
| 13794 | } |
| 13795 | } |
| 13796 | |
| 13797 | // Extension availability |
| 13798 | static bool ggml_vk_instance_validation_ext_available() { |
| 13799 | #ifdef GGML_VULKAN_VALIDATE |
| 13800 | // Check if validation layer provides the extension |
| 13801 | const std::string layer_name = "VK_LAYER_KHRONOS_validation" ; |
| 13802 | for (const auto& layer : vk::enumerateInstanceLayerProperties()) { |
| 13803 | if (layer_name == layer.layerName.data()) { |
| 13804 | for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) { |
| 13805 | if (strcmp("VK_EXT_validation_features" , ext.extensionName.data()) == 0) { |
| 13806 | return true; |
| 13807 | } |
| 13808 | } |
| 13809 | } |
| 13810 | } |
| 13811 | |
| 13812 | std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl; |
| 13813 | #endif |
| 13814 | return false; |
| 13815 | } |
| 13816 | static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) { |
| 13817 | #ifdef __APPLE__ |
| 13818 | // Check for portability enumeration extension for MoltenVK support |
| 13819 | for (const auto& properties : instance_extensions) { |
| 13820 | if (strcmp("VK_KHR_portability_enumeration" , properties.extensionName) == 0) { |
| 13821 | return true; |
| 13822 | } |
| 13823 | } |
| 13824 | std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; |
| 13825 | #endif |
| 13826 | return false; |
| 13827 | |
| 13828 | UNUSED(instance_extensions); |
| 13829 | } |
| 13830 | |
| 13831 | // Extension availability |
| 13832 | static bool ggml_vk_instance_debug_utils_ext_available( |
| 13833 | const std::vector<vk::ExtensionProperties> & instance_extensions) { |
| 13834 | // Check for the debug utils extension |
| 13835 | for (const auto & properties : instance_extensions) { |
| 13836 | if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) { |
| 13837 | return true; |
| 13838 | } |
| 13839 | } |
| 13840 | |
| 13841 | std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl; |
| 13842 | return false; |
| 13843 | |
| 13844 | UNUSED(instance_extensions); |
| 13845 | } |
| 13846 | |
| 13847 | static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) { |
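| | // Devices are considered supported only if they expose 16-bit storage buffer |
| | // access (Vulkan 1.1 storageBuffer16BitAccess). |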
| 13848 | VkPhysicalDeviceFeatures2 device_features2; |
| 13849 | device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; |
| 13850 | |
| 13851 | VkPhysicalDeviceVulkan11Features vk11_features; |
| 13852 | vk11_features.pNext = nullptr; |
| 13853 | vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; |
| 13854 | device_features2.pNext = &vk11_features; |
| 13855 | |
| 13856 | vkGetPhysicalDeviceFeatures2(vkdev, &device_features2); |
| 13857 | |
| 13858 | return vk11_features.storageBuffer16BitAccess; |
| 13859 | } |
| 13860 | |
| 13861 | static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { |
| 13862 | switch (props.vendorID) { |
| 13863 | case VK_VENDOR_ID_INTEL: |
| 13864 | // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost, |
| 13865 | // while some older hardware (ex. Arc A770) has performance regressions |
| 13866 | return arch == vk_device_architecture::INTEL_XE2; |
| 13867 | case VK_VENDOR_ID_AMD: |
| 13868 | if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { |
| 13869 | // Workaround for AMD proprietary driver reporting support on all GPUs |
| 13870 | return arch == vk_device_architecture::AMD_RDNA3; |
| 13871 | } |
| 13872 | return true; |
| 13873 | default: |
| 13874 | return true; |
| 13875 | } |
| 13876 | } |
| 13877 | |
| 13878 | // checks |
| 13879 | |
| 13880 | #ifdef GGML_VULKAN_CHECK_RESULTS |
| 13881 | static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) { |
| 13882 | if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { |
| 13883 | return; |
| 13884 | } |
| 13885 | for (int j = 0; j < level; j++) { |
| 13886 | std::cerr << " " ; |
| 13887 | } |
| 13888 | std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; |
| 13889 | |
| 13890 | done.push_back(tensor); |
| 13891 | |
| 13892 | for (int i = 0; i < GGML_MAX_SRC; i++) { |
| 13893 | if (tensor->src[i] != nullptr) { |
| 13894 | ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); |
| 13895 | } |
| 13896 | } |
| 13897 | } |
| 13898 | |
| 13899 | static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { |
| 13900 | if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { |
| 13901 | return; |
| 13902 | } |
| 13903 | i0 = std::max(i0, 5); |
| 13904 | i1 = std::max(i1, 5); |
| 13905 | i2 = std::max(i2, 0); |
| 13906 | i3 = std::max(i3, 0); |
| 13907 | fprintf(stderr, " " ); |
| 13908 | for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { |
| 13909 | fprintf(stderr, "%7d " , idx1); |
| 13910 | } |
| 13911 | fprintf(stderr, "\n" ); |
| 13912 | for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { |
| 13913 | fprintf(stderr, "%7d: " , idx0); |
| 13914 | for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { |
| 13915 | if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { |
| 13916 | float val; |
| 13917 | if (tensor->type == GGML_TYPE_F32) { |
| 13918 | val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); |
| 13919 | } else if (tensor->type == GGML_TYPE_F16) { |
| 13920 | val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); |
| 13921 | } else if (tensor->type == GGML_TYPE_I32) { |
| 13922 | val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); |
| 13923 | } else { |
| 13924 | GGML_ABORT("fatal error" ); |
| 13925 | } |
| 13926 | fprintf(stderr, "% 7.2f " , val); |
| 13927 | } else { |
| 13928 | fprintf(stderr, " " ); |
| 13929 | } |
| 13930 | } |
| 13931 | fprintf(stderr, "\n" ); |
| 13932 | } |
| 13933 | } |
| 13934 | |
| 13935 | static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { |
| 13936 | void * tensor_data = tensor->data; |
| 13937 | |
| 13938 | const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); |
| 13939 | |
| 13940 | if (is_gpu) { |
| 13941 | const size_t tensor_size = ggml_nbytes(tensor); |
| 13942 | tensor_data = malloc(tensor_size); |
| 13943 | |
| 13944 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; |
| 13945 | |
| 13946 | vk_buffer buffer_gpu = buf_ctx->dev_buffer; |
| 13947 | ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); |
| 13948 | } |
| 13949 | |
| 13950 | std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; |
| 13951 | std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; |
| 13952 | if (tensor->src[0] != nullptr) { |
| 13953 | std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; |
| 13954 | } |
| 13955 | if (tensor->src[1] != nullptr) { |
| 13956 | std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; |
| 13957 | } |
| 13958 | std::cerr << std::endl << "Result:" << std::endl; |
| 13959 | ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); |
| 13960 | std::cerr << std::endl; |
| 13961 | std::vector<const ggml_tensor *> done; |
| 13962 | ggml_vk_print_graph_origin(tensor, done); |
| 13963 | |
| 13964 | if (is_gpu) { |
| 13965 | free(tensor_data); |
| 13966 | } |
| 13967 | } |
| 13968 | |
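| | // State shared between the result-checking passes: the CPU reference result |
| | // (data, size and strides) and a counter used with the skip/output-tensor knobs. |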
| 13969 | void * comp_result; |
| 13970 | size_t comp_size; |
| 13971 | size_t comp_nb[GGML_MAX_DIMS]; |
| 13972 | size_t check_counter = 0; |
| 13973 | static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { |
| 13974 | ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops]; |
| 13975 | if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { |
| 13976 | return; |
| 13977 | } |
| 13978 | |
| 13979 | check_counter++; |
| 13980 | if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { |
| 13981 | return; |
| 13982 | } |
| 13983 | |
| 13984 | VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")" ); |
| 13985 | |
| 13986 | struct ggml_init_params iparams = { |
| 13987 | /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, |
| 13988 | /*.mem_buffer =*/ NULL, |
| 13989 | /*.no_alloc =*/ false, |
| 13990 | }; |
| 13991 | |
| 13992 | struct ggml_context * ggml_ctx = ggml_init(iparams); |
| 13993 | |
| 13994 | std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; |
| 13995 | const char * srci_name[GGML_MAX_SRC] = {"src0" , "src1" , "src2" , "src3" , "src4" , "src5" , "src6" , "src7" , "src8" , "src9" }; |
| 13996 | |
| 13997 | std::map<ggml_tensor *, ggml_tensor *> cloned_tensors; |
| 13998 | std::vector<void *> cloned_mallocs; |
| 13999 | |
| 14000 | struct ggml_tensor * tensor_clone = nullptr; |
| 14001 | |
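| | // Rebuild the (possibly fused) node sequence with plain ggml ops, first cloning |
| | // any GPU-resident sources into host memory so the reference runs on the CPU. |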
| 14002 | for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) { |
| 14003 | tensor = cgraph->nodes[tensor_idx + f]; |
| 14004 | for (int i = 0; i < GGML_MAX_SRC; i++) { |
| 14005 | ggml_tensor * srci = tensor->src[i]; |
| 14006 | if (srci == nullptr) { |
| 14007 | continue; |
| 14008 | } |
| 14009 | // If a src tensor has been cloned, use that one |
| 14010 | auto it = cloned_tensors.find(srci); |
| 14011 | if (it != cloned_tensors.end()) { |
| 14012 | src_clone[i] = it->second; |
| 14013 | continue; |
| 14014 | } |
| 14015 | ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); |
| 14016 | size_t srci_size = ggml_nbytes(srci); |
| 14017 | |
| 14018 | src_clone[i] = srci_clone; |
| 14019 | void *src_buffer = malloc(srci_size); |
| 14020 | cloned_mallocs.push_back(src_buffer); |
| 14021 | |
| 14022 | srci_clone->data = src_buffer; |
| 14023 | if (ggml_backend_buffer_is_host(srci->buffer)) { |
| 14024 | memcpy(srci_clone->data, srci->data, srci_size); |
| 14025 | memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); |
| 14026 | } else if (ggml_backend_buffer_is_vk(srci->buffer)) { |
| 14027 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; |
| 14028 | vk_buffer& buffer_gpu = buf_ctx->dev_buffer; |
| 14029 | uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; |
| 14030 | if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { |
| 14031 | for (int i3 = 0; i3 < srci->ne[3]; i3++) { |
| 14032 | for (int i2 = 0; i2 < srci->ne[2]; i2++) { |
| 14033 | const int idx = i3*srci->ne[2] + i2; |
| 14034 | ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); |
| 14035 | } |
| 14036 | } |
| 14037 | |
| 14038 | srci_clone->nb[0] = srci->nb[0]; |
| 14039 | srci_clone->nb[1] = srci->nb[1]; |
| 14040 | for (int i = 2; i < GGML_MAX_DIMS; i++) { |
| 14041 | srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; |
| 14042 | } |
| 14043 | } else { |
| 14044 | if (offset + srci_size >= buffer_gpu->size) { |
| 14045 | srci_size = buffer_gpu->size - offset; |
| 14046 | } |
| 14047 | ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); |
| 14048 | memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); |
| 14049 | } |
| 14050 | } else { |
| 14051 | GGML_ABORT("fatal error" ); |
| 14052 | } |
| 14053 | |
| 14054 | if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { |
| 14055 | ggml_vk_print_tensor(srci, srci_name[i]); |
| 14056 | } |
| 14057 | } |
| 14058 | |
| 14059 | if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { |
| 14060 | const float * params = (const float *)tensor->op_params; |
| 14061 | tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); |
| 14062 | if (src_clone[4]) { |
| 14063 | ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]); |
| 14064 | } |
| 14065 | } else if (tensor->op == GGML_OP_MUL_MAT) { |
| 14066 | tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); |
| 14067 | } else if (tensor->op == GGML_OP_MUL_MAT_ID) { |
| 14068 | tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); |
| 14069 | } else if (tensor->op == GGML_OP_SUB) { |
| 14070 | tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); |
| 14071 | } else if (tensor->op == GGML_OP_MUL) { |
| 14072 | tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); |
| 14073 | } else if (tensor->op == GGML_OP_DIV) { |
| 14074 | tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); |
| 14075 | } else if (tensor->op == GGML_OP_CONCAT) { |
| 14076 | tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); |
| 14077 | } else if (tensor->op == GGML_OP_UPSCALE) { |
| 14078 | tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); |
| 14079 | } else if (tensor->op == GGML_OP_SCALE) { |
| 14080 | const float * params = (const float *)tensor->op_params; |
| 14081 | tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); |
| 14082 | } else if (tensor->op == GGML_OP_SQR) { |
| 14083 | tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); |
| 14084 | } else if (tensor->op == GGML_OP_SQRT) { |
| 14085 | tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]); |
| 14086 | } else if (tensor->op == GGML_OP_SIN) { |
| 14087 | tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); |
| 14088 | } else if (tensor->op == GGML_OP_COS) { |
| 14089 | tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); |
| 14090 | } else if (tensor->op == GGML_OP_CLAMP) { |
| 14091 | const float * params = (const float *)tensor->op_params; |
| 14092 | tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); |
| 14093 | } else if (tensor->op == GGML_OP_PAD) { |
| 14094 | tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], |
| 14095 | tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); |
| 14096 | } else if (tensor->op == GGML_OP_REPEAT) { |
| 14097 | tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); |
| 14098 | } else if (tensor->op == GGML_OP_REPEAT_BACK) { |
| 14099 | tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); |
| 14100 | } else if (tensor->op == GGML_OP_ADD) { |
| 14101 | tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); |
| 14102 | } else if (tensor->op == GGML_OP_ACC) { |
| 14103 | tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); |
| 14104 | } else if (tensor->op == GGML_OP_NORM) { |
| 14105 | tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); |
| 14106 | } else if (tensor->op == GGML_OP_GROUP_NORM) { |
| 14107 | const float * float_params = (const float *)tensor->op_params; |
| 14108 | tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); |
| 14109 | } else if (tensor->op == GGML_OP_RMS_NORM) { |
| 14110 | tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); |
| 14111 | } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { |
| 14112 | const float eps = ((float *) tensor->op_params)[0]; |
| 14113 | tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); |
| 14114 | } else if (tensor->op == GGML_OP_SILU_BACK) { |
| 14115 | tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); |
| 14116 | } else if (tensor->op == GGML_OP_L2_NORM) { |
| 14117 | const float eps = ((float *) tensor->op_params)[0]; |
| 14118 | tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); |
| 14119 | } else if (tensor->op == GGML_OP_SOFT_MAX) { |
| 14120 | if (tensor->src[1] != nullptr) { |
| 14121 | const float * params = (const float *)tensor->op_params; |
| 14122 | tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); |
| 14123 | } else { |
| 14124 | tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); |
| 14125 | } |
| 14126 | } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { |
| 14127 | tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); |
| 14128 | } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { |
| 14129 | tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); |
| 14130 | } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { |
| 14131 | const int n_dims = ((int32_t *) tensor->op_params)[1]; |
| 14132 | const int mode = ((int32_t *) tensor->op_params)[2]; |
| 14133 | //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; |
| 14134 | const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; |
| 14135 | const float freq_base = ((float *) tensor->op_params)[5]; |
| 14136 | const float freq_scale = ((float *) tensor->op_params)[6]; |
| 14137 | const float ext_factor = ((float *) tensor->op_params)[7]; |
| 14138 | const float attn_factor = ((float *) tensor->op_params)[8]; |
| 14139 | const float beta_fast = ((float *) tensor->op_params)[9]; |
| 14140 | const float beta_slow = ((float *) tensor->op_params)[10]; |
| 14141 | if (mode & GGML_ROPE_TYPE_MROPE) { |
| 14142 | int32_t *sections = ((int32_t *) tensor->op_params) + 11; |
| 14143 | if (tensor->op == GGML_OP_ROPE) { |
| 14144 | tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); |
| 14145 | } else { |
| 14146 | tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); |
| 14147 | } |
| 14148 | } else { |
| 14149 | if (tensor->op == GGML_OP_ROPE) { |
| 14150 | tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); |
| 14151 | } else { |
| 14152 | tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); |
| 14153 | } |
| 14154 | } |
| 14155 | } else if (tensor->op == GGML_OP_UNARY) { |
| 14156 | switch (ggml_get_unary_op(tensor)) { |
| 14157 | case GGML_UNARY_OP_EXP: |
| 14158 | tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); |
| 14159 | break; |
| 14160 | case GGML_UNARY_OP_SILU: |
| 14161 | tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); |
| 14162 | break; |
| 14163 | case GGML_UNARY_OP_GELU: |
| 14164 | tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); |
| 14165 | break; |
| 14166 | case GGML_UNARY_OP_GELU_ERF: |
| 14167 | tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]); |
| 14168 | break; |
| 14169 | case GGML_UNARY_OP_GELU_QUICK: |
| 14170 | tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); |
| 14171 | break; |
| 14172 | case GGML_UNARY_OP_RELU: |
| 14173 | tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); |
| 14174 | break; |
| 14175 | case GGML_UNARY_OP_TANH: |
| 14176 | tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); |
| 14177 | break; |
| 14178 | case GGML_UNARY_OP_SIGMOID: |
| 14179 | tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); |
| 14180 | break; |
| 14181 | case GGML_UNARY_OP_HARDSIGMOID: |
| 14182 | tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]); |
| 14183 | break; |
| 14184 | case GGML_UNARY_OP_HARDSWISH: |
| 14185 | tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]); |
| 14186 | break; |
| 14187 | default: |
| 14188 | std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; |
| 14189 | GGML_ABORT("fatal error" ); |
| 14190 | } |
| 14191 | } else if (tensor->op == GGML_OP_GLU) { |
| 14192 | if (src_clone[1] == nullptr) { |
| 14193 | tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); |
| 14194 | } else { |
| 14195 | tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); |
| 14196 | } |
| 14197 | ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2)); |
| 14198 | ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3)); |
| 14199 | } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { |
| 14200 | if (tensor->src[1] == nullptr) { |
| 14201 | tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); |
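| | // ggml_dup infers the type from the source, so force the clone to the checked tensor's type in case they differ |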
| 14202 | tensor_clone->type = tensor->type; |
| 14203 | } else { |
| 14204 | tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); |
| 14205 | } |
| 14206 | } else if (tensor->op == GGML_OP_CONT) { |
| 14207 | tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); |
| 14208 | } else if (tensor->op == GGML_OP_RESHAPE) { |
| 14209 | tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); |
| 14210 | } else if (tensor->op == GGML_OP_VIEW) { |
| 14211 | tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); |
| 14212 | } else if (tensor->op == GGML_OP_PERMUTE) { |
| 14213 | int32_t * params = (int32_t *)tensor->op_params; |
| 14214 | tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); |
| 14215 | } else if (tensor->op == GGML_OP_TRANSPOSE) { |
| 14216 | tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); |
| 14217 | } else if (tensor->op == GGML_OP_GET_ROWS) { |
| 14218 | tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); |
| 14219 | } else if (tensor->op == GGML_OP_ARGSORT) { |
| 14220 | tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); |
| 14221 | } else if (tensor->op == GGML_OP_SUM) { |
| 14222 | tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); |
| 14223 | } else if (tensor->op == GGML_OP_SUM_ROWS) { |
| 14224 | tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); |
| 14225 | } else if (tensor->op == GGML_OP_MEAN) { |
| 14226 | tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); |
| 14227 | } else if (tensor->op == GGML_OP_ARGMAX) { |
| 14228 | tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); |
| 14229 | } else if (tensor->op == GGML_OP_COUNT_EQUAL) { |
| 14230 | tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); |
| 14231 | } else if (tensor->op == GGML_OP_IM2COL) { |
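| | // op_params hold the strides (s0, s1), paddings (p0, p1) and dilations (d0, d1); op_params[6] selects 2D vs 1D |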
| 14232 | const int32_t s0 = tensor->op_params[0]; |
| 14233 | const int32_t s1 = tensor->op_params[1]; |
| 14234 | const int32_t p0 = tensor->op_params[2]; |
| 14235 | const int32_t p1 = tensor->op_params[3]; |
| 14236 | const int32_t d0 = tensor->op_params[4]; |
| 14237 | const int32_t d1 = tensor->op_params[5]; |
| 14238 | |
| 14239 | const bool is_2D = tensor->op_params[6] == 1; |
| 14240 | tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); |
| 14241 | } else if (tensor->op == GGML_OP_IM2COL_3D) { |
| 14242 | const int32_t s0 = tensor->op_params[0]; |
| 14243 | const int32_t s1 = tensor->op_params[1]; |
| 14244 | const int32_t s2 = tensor->op_params[2]; |
| 14245 | const int32_t p0 = tensor->op_params[3]; |
| 14246 | const int32_t p1 = tensor->op_params[4]; |
| 14247 | const int32_t p2 = tensor->op_params[5]; |
| 14248 | const int32_t d0 = tensor->op_params[6]; |
| 14249 | const int32_t d1 = tensor->op_params[7]; |
| 14250 | const int32_t d2 = tensor->op_params[8]; |
| 14251 | const int32_t IC = tensor->op_params[9]; |
| 14252 | |
| 14253 | tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); |
| 14254 | } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { |
| 14255 | const int32_t dim = tensor->op_params[0]; |
| 14256 | const int32_t max_period = tensor->op_params[1]; |
| 14257 | tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); |
| 14258 | } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ |
| 14259 | const int32_t s0 = tensor->op_params[0]; |
| 14260 | const int32_t p0 = tensor->op_params[1]; |
| 14261 | const int32_t d0 = tensor->op_params[2]; |
| 14262 | tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); |
| 14263 | } else if (tensor->op == GGML_OP_POOL_2D) { |
| 14264 | enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]); |
| 14265 | const int32_t k0 = tensor->op_params[1]; |
| 14266 | const int32_t k1 = tensor->op_params[2]; |
| 14267 | const int32_t s0 = tensor->op_params[3]; |
| 14268 | const int32_t s1 = tensor->op_params[4]; |
| 14269 | const int32_t p0 = tensor->op_params[5]; |
| 14270 | const int32_t p1 = tensor->op_params[6]; |
| 14271 | |
| 14272 | tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); |
| 14273 | } else if (tensor->op == GGML_OP_CONV_2D) { |
| 14274 | const int32_t s0 = tensor->op_params[0]; |
| 14275 | const int32_t s1 = tensor->op_params[1]; |
| 14276 | const int32_t p0 = tensor->op_params[2]; |
| 14277 | const int32_t p1 = tensor->op_params[3]; |
| 14278 | const int32_t d0 = tensor->op_params[4]; |
| 14279 | const int32_t d1 = tensor->op_params[5]; |
| 14280 | tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); |
| 14281 | } else if (tensor->op == GGML_OP_CONV_2D_DW) { |
| 14282 | const int32_t s0 = tensor->op_params[0]; |
| 14283 | const int32_t s1 = tensor->op_params[1]; |
| 14284 | const int32_t p0 = tensor->op_params[2]; |
| 14285 | const int32_t p1 = tensor->op_params[3]; |
| 14286 | const int32_t d0 = tensor->op_params[4]; |
| 14287 | const int32_t d1 = tensor->op_params[5]; |
| 14288 | tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); |
| 14289 | } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) { |
| 14290 | const int32_t s = tensor->op_params[0]; |
| 14291 | tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s); |
| 14292 | } else if (tensor->op == GGML_OP_LEAKY_RELU) { |
| 14293 | const float * op_params = (const float *)tensor->op_params; |
| 14294 | tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); |
| 14295 | } else if (tensor->op == GGML_OP_RWKV_WKV6) { |
| 14296 | tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], |
| 14297 | src_clone[2], src_clone[3], src_clone[4], src_clone[5]); |
| 14298 | } else if (tensor->op == GGML_OP_RWKV_WKV7) { |
| 14299 | tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], |
| 14300 | src_clone[4], src_clone[5], src_clone[6]); |
| 14301 | } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { |
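| | // the optimizer step ops expect src0 to carry the original tensor flags (e.g. GGML_TENSOR_FLAG_PARAM), |
| | // so copy them onto the clone (same for the SGD branch below) |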
| 14302 | src_clone[0]->flags = tensor->src[0]->flags; |
| 14303 | tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], |
| 14304 | src_clone[2], src_clone[3], src_clone[4]); |
| 14305 | } else if (tensor->op == GGML_OP_OPT_STEP_SGD) { |
| 14306 | src_clone[0]->flags = tensor->src[0]->flags; |
| 14307 | tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1], |
| 14308 | src_clone[2]); |
| 14309 | } else if (tensor->op == GGML_OP_ADD_ID) { |
| 14310 | tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); |
| 14311 | } else if (tensor->op == GGML_OP_SSM_SCAN) { |
| 14312 | tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], |
| 14313 | src_clone[3], src_clone[4], src_clone[5], src_clone[6]); |
| 14314 | } else if (tensor->op == GGML_OP_SSM_CONV) { |
| 14315 | tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]); |
| 14316 | } else if (tensor->op == GGML_OP_ROLL) { |
| 14317 | const int32_t s0 = tensor->op_params[0]; |
| 14318 | const int32_t s1 = tensor->op_params[1]; |
| 14319 | const int32_t s2 = tensor->op_params[2]; |
| 14320 | const int32_t s3 = tensor->op_params[3]; |
| 14321 | tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3); |
| 14322 | } |
| 14323 | else { |
| 14324 | std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; |
| 14325 | GGML_ABORT("fatal error"); |
| 14326 | } |
| 14327 | cloned_tensors[tensor] = tensor_clone; |
| 14328 | } |
| 14329 | |
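| | // build a one-node graph around the cloned op and compute it on the CPU to obtain the reference result |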
| 14330 | ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); |
| 14331 | ggml_build_forward_expand(cgraph_cpu, tensor_clone); |
| 14332 | |
| 14333 | ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8); |
| 14334 | |
| 14335 | if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { |
| 14336 | ggml_vk_print_tensor(tensor_clone, "tensor_clone"); |
| 14337 | } |
| 14338 | |
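| | // stash the CPU result, its size and its strides in the file-scope comp_* buffers; |
| | // ggml_vk_check_results_1 compares the GPU output against them |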
| 14339 | comp_size = ggml_nbytes(tensor_clone); |
| 14340 | |
| 14341 | comp_result = malloc(comp_size); |
| 14342 | memcpy(comp_result, tensor_clone->data, comp_size); |
| 14343 | memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); |
| 14344 | |
| 14345 | for (auto m : cloned_mallocs) { |
| 14346 | free(m); |
| 14347 | } |
| 14348 | |
| 14349 | ggml_free(ggml_ctx); |
| 14350 | |
| 14351 | VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); |
| 14352 | } |
| 14353 | |
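| | // Compare the Vulkan result of the node at tensor_idx against the CPU reference captured by |
| | // ggml_vk_check_results_0: record the first element whose relative error exceeds 0.5 and abort |
| | // if the average relative error is above 0.5 or NaN. |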
| 14354 | static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { |
| 14355 | ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops]; |
| 14356 | if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { |
| 14357 | return; |
| 14358 | } |
| 14359 | |
| 14360 | if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { |
| 14361 | return; |
| 14362 | } |
| 14363 | |
| 14364 | VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); |
| 14365 | |
| 14366 | ggml_tensor * src0 = tensor->src[0]; |
| 14367 | ggml_tensor * src1 = tensor->src[1]; |
| 14368 | ggml_tensor * src2 = tensor->src[2]; |
| 14369 | ggml_tensor * src3 = tensor->src[3]; |
| 14370 | |
| 14371 | void * tensor_data = tensor->data; |
| 14372 | |
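| | // results living in a Vulkan buffer have to be read back into host memory before they can be compared |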
| 14373 | if (ggml_backend_buffer_is_vk(tensor->buffer)) { |
| 14374 | size_t tensor_size = ggml_nbytes(tensor); |
| 14375 | tensor_data = malloc(tensor_size); |
| 14376 | |
| 14377 | ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; |
| 14378 | |
| 14379 | vk_buffer& buffer_gpu = buf_ctx->dev_buffer; |
| 14380 | uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; |
| 14381 | if (offset + tensor_size >= buffer_gpu->size) { |
| 14382 | tensor_size = buffer_gpu->size - offset; |
| 14383 | } |
| 14384 | |
| 14385 | ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); |
| 14386 | } |
| 14387 | |
| 14388 | float first_error_result = -1.0f; |
| 14389 | float first_error_correct = -1.0f; |
| 14390 | std::array<int, 4> first_error = { -1, -1, -1, -1 }; |
| 14391 | double avg_err = 0.0; |
| 14392 | size_t counter = 0; |
| 14393 | |
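| | // walk every element, accumulating the relative error and remembering the first element that deviates noticeably |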
| 14394 | for (int i3 = 0; i3 < tensor->ne[3]; i3++) { |
| 14395 | for (int i2 = 0; i2 < tensor->ne[2]; i2++) { |
| 14396 | for (int i1 = 0; i1 < tensor->ne[1]; i1++) { |
| 14397 | for (int i0 = 0; i0 < tensor->ne[0]; i0++) { |
| 14398 | const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; |
| 14399 | float correct = 0.0f; |
| 14400 | float result = 0.0f; |
| 14401 | |
| 14402 | if (buffer_size_fit) { |
| 14403 | if (tensor->type == GGML_TYPE_F32) { |
| 14404 | correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); |
| 14405 | result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); |
| 14406 | } else if (tensor->type == GGML_TYPE_F16) { |
| 14407 | correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); |
| 14408 | result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); |
| 14409 | } else if (tensor->type == GGML_TYPE_BF16) { |
| 14410 | correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); |
| 14411 | result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); |
| 14412 | } else if (tensor->type == GGML_TYPE_I32) { |
| 14413 | correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); |
| 14414 | result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); |
| 14415 | } else if (tensor->type == GGML_TYPE_I64) { |
| 14416 | correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); |
| 14417 | result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); |
| 14418 | } else { |
| 14419 | std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; |
| 14420 | } |
| 14421 | } else { |
| 14422 | std::cerr << "ERROR: index out of bounds of the CPU reference buffer for " << ggml_op_name(tensor->op) << " (type " << ggml_type_name(tensor->type) << ")" << std::endl; |
| 14423 | GGML_ABORT("fatal error"); |
| 14424 | } |
| 14425 | |
| 14426 | if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { |
| 14427 | std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; |
| 14428 | std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; |
| 14429 | if (src0 != nullptr) { |
| 14430 | std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; |
| 14431 | } |
| 14432 | if (src1 != nullptr) { |
| 14433 | std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; |
| 14434 | } |
| 14435 | if (src2 != nullptr) { |
| 14436 | std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; |
| 14437 | } |
| 14438 | if (src3 != nullptr) { |
| 14439 | std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; |
| 14440 | } |
| 14441 | std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; |
| 14442 | std::cerr << std::endl << "Result:" << std::endl; |
| 14443 | ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); |
| 14444 | std::cerr << std::endl << "Correct:" << std::endl; |
| 14445 | ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); |
| 14446 | std::cerr << std::endl; |
| 14447 | std::vector<const ggml_tensor *> done; |
| 14448 | ggml_vk_print_graph_origin(tensor, done); |
| 14449 | GGML_ABORT("fatal error"); |
| 14450 | } |
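| | // relative error: normalize by fabs(correct) when it is larger than 1, otherwise use the absolute error |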
| 14451 | const double denom = std::fabs(correct) > 1.0f ? std::fabs(correct) : 1.0; |
| 14452 | if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) { |
| 14453 | first_error[0] = i0; |
| 14454 | first_error[1] = i1; |
| 14455 | first_error[2] = i2; |
| 14456 | first_error[3] = i3; |
| 14457 | first_error_result = result; |
| 14458 | first_error_correct = correct; |
| 14459 | } |
| 14460 | |
| 14461 | // Skip elements where either value is infinite or NaN to avoid producing a NaN avg_err. |
| 14462 | // Mismatched NaN/inf pairs were already caught above, so matching pairs simply contribute an error of 0. |
| 14463 | if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { |
| 14464 | avg_err += std::fabs(correct - result) / denom; |
| 14465 | } |
| 14466 | counter++; |
| 14467 | } |
| 14468 | } |
| 14469 | } |
| 14470 | } |
| 14471 | |
| 14472 | avg_err /= counter; |
| 14473 | |
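| | // when a specific check was requested via vk_output_tensor, always dump the full comparison details |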
| 14474 | if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { |
| 14475 | std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; |
| 14476 | std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; |
| 14477 | if (src0 != nullptr) { |
| 14478 | std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; |
| 14479 | } |
| 14480 | if (src1 != nullptr) { |
| 14481 | std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; |
| 14482 | } |
| 14483 | if (src2 != nullptr) { |
| 14484 | std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; |
| 14485 | } |
| 14486 | if (src3 != nullptr) { |
| 14487 | std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; |
| 14488 | } |
| 14489 | std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; |
| 14490 | std::cerr << std::endl << "Result:" << std::endl; |
| 14491 | ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); |
| 14492 | std::cerr << std::endl << "Correct:" << std::endl; |
| 14493 | ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); |
| 14494 | std::cerr << std::endl; |
| 14495 | std::vector<const ggml_tensor *> done; |
| 14496 | ggml_vk_print_graph_origin(tensor, done); |
| 14497 | } |
| 14498 | |
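| | // fail hard when the average relative error is too large or not a number |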
| 14499 | if (avg_err > 0.5 || std::isnan(avg_err)) { |
| 14500 | std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; |
| 14501 | std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; |
| 14502 | if (src0 != nullptr) { |
| 14503 | std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; |
| 14504 | } |
| 14505 | if (src1 != nullptr) { |
| 14506 | std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; |
| 14507 | } |
| 14508 | if (src2 != nullptr) { |
| 14509 | std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; |
| 14510 | } |
| 14511 | if (src3 != nullptr) { |
| 14512 | std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; |
| 14513 | } |
| 14514 | std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; |
| 14515 | std::cerr << std::endl << "Result:" << std::endl; |
| 14516 | ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); |
| 14517 | std::cerr << std::endl << "Correct:" << std::endl; |
| 14518 | ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); |
| 14519 | std::cerr << std::endl; |
| 14520 | std::vector<const ggml_tensor *> done; |
| 14521 | ggml_vk_print_graph_origin(tensor, done); |
| 14522 | GGML_ABORT("fatal error"); |
| 14523 | } else { |
| 14524 | std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; |
| 14525 | } |
| 14526 | |
| 14527 | free(comp_result); |
| 14528 | comp_result = nullptr; |
| 14529 | comp_size = 0; |
| 14530 | |
| 14531 | if (ggml_backend_buffer_is_vk(tensor->buffer)) { |
| 14532 | free(tensor_data); |
| 14533 | } |
| 14534 | |
| 14535 | VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); |
| 14536 | } |
| 14537 | #endif |
| 14538 | |
| 14539 | GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) |
| 14540 | |