1#include "ggml-vulkan.h"
2#include <vulkan/vulkan_core.h>
3#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
4#include <chrono>
5#include "ggml-cpu.h"
6#endif
7
8// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
9#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
10// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
11// to avoid conflicts with applications or other libraries who might use it.
12#if VK_HEADER_VERSION >= 301
13namespace vk::detail { class DispatchLoaderDynamic; }
14using vk::detail::DispatchLoaderDynamic;
15#else
16namespace vk { class DispatchLoaderDynamic; }
17using vk::DispatchLoaderDynamic;
18#endif
19DispatchLoaderDynamic & ggml_vk_default_dispatcher();
20#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
21
22#include <vulkan/vulkan.hpp>
23
24#include <algorithm>
25#include <cmath>
26#include <iomanip>
27#include <iostream>
28#include <tuple>
29#include <vector>
30#include <sstream>
31#include <utility>
32#include <memory>
33#include <limits>
34#include <map>
35#include <unordered_map>
37#include <mutex>
38#include <future>
39#include <thread>
40
41#if defined(_MSC_VER)
42# define NOMINMAX 1
43# include <windows.h>
44# define YIELD() YieldProcessor()
45#elif defined(__clang__) || defined(__GNUC__)
# if defined(__x86_64__) || defined(__i386__)
47# include <immintrin.h>
48# define YIELD() _mm_pause()
49# elif defined(__arm__) || defined(__aarch64__)
50# if defined(__clang__)
51# include <arm_acle.h>
52# define YIELD() __yield()
53# else
54# define YIELD() asm volatile("yield")
55# endif
56# endif
57#endif
58
59#if !defined(YIELD)
60#define YIELD()
61#endif
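// YIELD() is used in spin-waits such as ggml_vk_wait_for_fence below; the empty fallback simply
// turns those waits into plain busy loops on platforms without a pause/yield hint.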
62
63#include "ggml-impl.h"
64#include "ggml-backend-impl.h"
65
66#include "ggml-vulkan-shaders.hpp"
67
68// remove this once it's more widely available in the SDK
69#if !defined(VK_KHR_shader_bfloat16)
70
71#define VK_KHR_shader_bfloat16 1
72#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
73#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
74#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
75#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
76
77typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
78 VkStructureType sType;
79 void* pNext;
80 VkBool32 shaderBFloat16Type;
81 VkBool32 shaderBFloat16DotProduct;
82 VkBool32 shaderBFloat16CooperativeMatrix;
83} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
84#endif
85
86#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
87#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
88static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
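// Examples: ROUNDUP_POW2(10, 8) == 16 and CEIL_DIV(10, 8) == 2. ROUNDUP_POW2 requires N to be a
// power of two; CEIL_DIV works for any positive N.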
89
90#define VK_VENDOR_ID_AMD 0x1002
91#define VK_VENDOR_ID_APPLE 0x106b
92#define VK_VENDOR_ID_INTEL 0x8086
93#define VK_VENDOR_ID_NVIDIA 0x10de
94
95#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
96
97#define GGML_VK_MAX_NODES 8192
98
99#define VK_CHECK(err, msg) \
100 do { \
101 vk::Result err_ = (err); \
102 if (err_ != vk::Result::eSuccess) { \
103 fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \
104 #err, to_string(err_).c_str(), __FILE__, __LINE__); \
105 exit(1); \
106 } \
107 } while (0)
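// Usage: wrap calls that return a vk::Result, e.g.
//   VK_CHECK(device.waitForFences({ fence }, true, UINT64_MAX), "waitForFences");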
108
109#ifdef GGML_VULKAN_DEBUG
110#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl
111#else
112#define VK_LOG_DEBUG(msg) ((void) 0)
113#endif // GGML_VULKAN_DEBUG
114
115struct ggml_backend_vk_context;
116
117#define MAX_PARAMETER_COUNT 12
118// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
119#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
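// e.g. with MAX_PARAMETER_COUNT == 12 up to 9 adds can be fused: a chain of N adds binds N+1
// source buffers plus the destination, and one more binding is presumably reserved for the
// rms_norm partial sums used by the fused add+rms path.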
120
121struct vk_pipeline_struct {
122 std::string name;
123 vk::ShaderModule shader_module;
124 vk::PipelineLayout layout;
125 vk::Pipeline pipeline;
126 uint32_t push_constant_size;
127 uint32_t parameter_count;
128 std::array<uint32_t, 3> wg_denoms;
129 uint32_t align;
130 // true if fields have been set by ggml_vk_create_pipeline
131 bool initialized {};
132 // set to true to request the pipeline is compiled
133 std::atomic<bool> needed {};
134 // set to true when the shader has been compiled
135 std::atomic<bool> compiled {};
136 // number of registers used, extracted from pipeline executable properties
137 uint32_t register_count {};
138};
139
140typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
141typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
142
143static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
144
145struct vk_matmul_pipeline_struct {
146 vk_pipeline l, m, s;
147 vk_pipeline a_l, a_m, a_s;
148 // Returns true when all unaligned pipelines are null.
149 // We only check for unaligned variants since one of the unaligned pipelines must exist
150 // while aligned pipelines are optional
151 bool is_empty() const {
152 return l == nullptr && m == nullptr && s == nullptr;
153 }
154};
155typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
156
157struct vk_matmul_pipeline2 {
158 vk_matmul_pipeline2() {
159 f16acc = std::make_shared<vk_matmul_pipeline_struct>();
160 f32acc = std::make_shared<vk_matmul_pipeline_struct>();
161 }
162 vk_matmul_pipeline f32acc;
163 vk_matmul_pipeline f16acc;
164};
165
166struct vk_device_struct;
167typedef std::shared_ptr<vk_device_struct> vk_device;
168typedef std::weak_ptr<vk_device_struct> vk_device_ref;
169
170struct vk_buffer_struct;
171typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
172typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
173
174struct ggml_backend_vk_buffer_type_context {
175 std::string name;
176 vk_device device;
177};
178
179struct vk_queue;
180
181// Stores command pool/buffers. There's an instance of this
182// for each (context,queue) pair and for each (device,queue) pair.
183struct vk_command_pool {
184 void init(vk_device& device, vk_queue *q_);
185 void destroy(vk::Device& device);
186
187 vk::CommandPool pool;
188 uint32_t cmd_buffer_idx;
189 std::vector<vk::CommandBuffer> cmd_buffers;
190
191 vk_queue *q;
192};
193
194// Prevent simultaneous submissions to the same queue.
195// This could be per vk_queue if we stopped having two vk_queue structures
196// sharing the same vk::Queue.
197static std::mutex queue_mutex;
198
199struct vk_queue {
200 uint32_t queue_family_index;
201 vk::Queue queue;
202
203 vk_command_pool cmd_pool;
204
205 vk::PipelineStageFlags stage_flags;
206
207 bool transfer_only;
208
209 // copy everything except the cmd_pool
210 void copyFrom(vk_queue &other) {
211 queue_family_index = other.queue_family_index;
212 queue = other.queue;
213 stage_flags = other.stage_flags;
214 transfer_only = other.transfer_only;
215 }
216};
217
218static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
219static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
220static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
221static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
222static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
223static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
224 /* .get_name = */ ggml_backend_vk_buffer_type_name,
225 /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer,
226 /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment,
227 /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
228 /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size,
229 /* .is_host = */ NULL,
230};
231
232#ifdef GGML_VULKAN_MEMORY_DEBUG
233class vk_memory_logger;
234#endif
235class vk_perf_logger;
236static void ggml_vk_destroy_buffer(vk_buffer& buf);
237
238static constexpr uint32_t mul_mat_vec_max_cols = 8;
239static constexpr uint32_t p021_max_gqa_ratio = 8;
240
241enum vk_device_architecture {
242 OTHER,
243 AMD_GCN,
244 AMD_RDNA1,
245 AMD_RDNA2,
246 AMD_RDNA3,
247 INTEL_XE2,
248 NVIDIA_PRE_TURING,
249};
250
251static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
252 vk::PhysicalDeviceProperties props = device.getProperties();
253
254 if (props.vendorID == VK_VENDOR_ID_AMD) {
255 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
256
257 bool amd_shader_core_properties = false;
258 bool integer_dot_product = false;
259 bool subgroup_size_control = false;
260
261 for (const auto& properties : ext_props) {
            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
                amd_shader_core_properties = true;
            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
                integer_dot_product = true;
            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
                subgroup_size_control = true;
268 }
269 }
270
271 if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
272 return vk_device_architecture::OTHER;
273 }
274
275 vk::PhysicalDeviceProperties2 props2;
276 vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
277 vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
278 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
279
280 props2.pNext = &shader_core_props_amd;
281 shader_core_props_amd.pNext = &integer_dot_props;
282 integer_dot_props.pNext = &subgroup_size_control_props;
283
        device.getProperties2(&props2);
285
286 if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
287 return vk_device_architecture::AMD_GCN;
288 }
289 if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
290 // RDNA
291 if (shader_core_props_amd.wavefrontsPerSimd == 20) {
292 return vk_device_architecture::AMD_RDNA1;
293 }
294 if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
295 return vk_device_architecture::AMD_RDNA3;
296 }
297 return vk_device_architecture::AMD_RDNA2;
298 }
299 } else if (props.vendorID == VK_VENDOR_ID_INTEL) {
300 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
301
302 bool subgroup_size_control = false;
303
304 for (const auto& properties : ext_props) {
            if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
306 subgroup_size_control = true;
307 }
308 }
309
310 if (!subgroup_size_control) {
311 return vk_device_architecture::OTHER;
312 }
313
314 vk::PhysicalDeviceProperties2 props2;
315 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
316
317 props2.pNext = &subgroup_size_control_props;
        device.getProperties2(&props2);
319
320 if (subgroup_size_control_props.minSubgroupSize == 16) {
321 // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8.
322 // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value.
323 // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
324 // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
325 return vk_device_architecture::INTEL_XE2;
326 }
327 } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
328 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
329
330 bool cooperative_matrix = false;
331
332 // Detect "pre-turing" based on lack of coopmat support.
333 for (const auto& properties : ext_props) {
            if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
335 cooperative_matrix = true;
336 break;
337 }
338 }
339
340 if (!cooperative_matrix) {
341 return vk_device_architecture::NVIDIA_PRE_TURING;
342 }
343 }
344 return vk_device_architecture::OTHER;
345}
346
347enum vk_conv_shapes {
348 CONV_SHAPE_128x128,
349 CONV_SHAPE_64x32,
350 CONV_SHAPE_32x256,
351 CONV_SHAPE_COUNT,
352};
353
354enum dmmv_wg_sizes {
355 DMMV_WG_SIZE_SUBGROUP,
356 DMMV_WG_SIZE_LARGE,
357 DMMV_WG_SIZE_COUNT,
358};
359
360enum FaCodePath {
361 FA_SCALAR,
362 FA_COOPMAT1,
363 FA_COOPMAT2,
364};
365
366struct vk_fa_pipeline_state {
367 vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
368 : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
369
370 uint32_t HSK, HSV;
371 bool small_rows;
372 FaCodePath path;
373 bool aligned;
374 bool f32acc;
375
376 bool operator<(const vk_fa_pipeline_state &b) const {
        return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
               std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
379 }
380};
381
382enum shader_reduction_mode {
383 SHADER_REDUCTION_MODE_SHMEM,
384 SHADER_REDUCTION_MODE_HYBRID,
385 SHADER_REDUCTION_MODE_SUBGROUP,
386 SHADER_REDUCTION_MODE_COUNT,
387};
388
389static constexpr uint32_t num_argsort_pipelines = 11;
390static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
391static constexpr uint32_t num_topk_moe_pipelines = 10;
392
393static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
394 GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
395 GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
396 GGML_OP_RESHAPE };
397static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
398 GGML_OP_VIEW, GGML_OP_GET_ROWS };
399static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
400 GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
401 GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
402
403//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
404//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
405//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
406//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
407//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
408//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
409//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
410//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
411//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
412//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
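// Each edge below is { node index, src slot, node index } relative to the start of the fused run:
// nodes[i].src[j] must be nodes[k] for the pattern to match (see the per-edge comments).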
413static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
414 { 1, 0, 0 }, // reshape->src[0] == softmax
415 { 2, 0, 0 }, // argsort->src[0] == softmax
416 { 3, 0, 2 }, // view->src[0] == argsort
417 { 4, 0, 1 }, // get_rows->src[0] == reshape
418 { 4, 1, 3 }, // get_rows->src[1] == view
419 { 5, 0, 4 }, // reshape->src[0] == get_rows
420 { 6, 0, 5 }, // sum_rows->src[0] == reshape
421 { 7, 0, 6 }, // clamp->src[0] == sum_rows
422 { 8, 0, 5 }, // div->src[0] == reshape
423 { 8, 1, 7 }, // div->src[1] == clamp
424 { 9, 0, 8 }, // reshape->src[0] == div
425};
426
427// same as early_softmax_norm but ending after the get_rows
428static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
429 { 1, 0, 0 }, // reshape->src[0] == softmax
430 { 2, 0, 0 }, // argsort->src[0] == softmax
431 { 3, 0, 2 }, // view->src[0] == argsort
432 { 4, 0, 1 }, // get_rows->src[0] == reshape
433 { 4, 1, 3 }, // get_rows->src[1] == view
434};
435
436//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
437//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
438//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
439//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
440//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
441//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
442static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
443 { 1, 0, 0 }, // view->src[0] == argsort
444 { 2, 1, 1 }, // get_rows->src[1] == view
445 { 3, 0, 2 }, // reshape->src[0] == get_rows
446 { 4, 0, 3 }, // soft_max->src[0] == reshape
447 { 5, 0, 4 }, // reshape->src[0] == soft_max
448};
449
450enum topk_moe_mode {
451 TOPK_MOE_EARLY_SOFTMAX,
452 TOPK_MOE_EARLY_SOFTMAX_NORM,
453 TOPK_MOE_LATE_SOFTMAX,
454 TOPK_MOE_COUNT,
455};
456
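// Maps the number of additional fused ops back to the matched pattern; e.g. the 10-op
// topk_moe_early_softmax_norm chain arrives here as num == 9.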
457static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
458 topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
459 num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
460 TOPK_MOE_LATE_SOFTMAX;
461 return mode;
462}
463
464static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
465 { 1, 0, 0 }, // view->src[0] == rope
466 { 2, 0, 1 }, // set_rows->src[0] == view
467};
468
469static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges {
470 { 1, 0, 0 }, // mul->src[0] == rms
471 { 2, 0, 1 }, // rope->src[0] == mul
472 { 3, 0, 2 }, // view->src[0] == rope
473 { 4, 0, 3 }, // set_rows->src[0] == view
474};
475
476
477struct vk_device_struct {
478 std::recursive_mutex mutex;
479
480 vk::PhysicalDevice physical_device;
481 vk::PhysicalDeviceProperties properties;
482 std::string name;
483 uint64_t max_memory_allocation_size;
484 uint64_t max_buffer_size;
485 uint64_t suballocation_block_size;
486 bool fp16;
487 bool bf16;
488 bool pipeline_robustness;
489 vk::Device device;
490 uint32_t vendor_id;
491 vk::DriverId driver_id;
492 vk_device_architecture architecture;
493 vk_queue compute_queue;
494 vk_queue transfer_queue;
495 bool single_queue;
496 uint32_t subgroup_size;
497 uint32_t shader_core_count;
498 bool uma;
499 bool prefer_host_memory;
500 bool float_controls_rte_fp16;
501 bool subgroup_arithmetic;
502 bool subgroup_shuffle;
503 bool subgroup_ballot;
504 bool subgroup_clustered;
505 bool multi_add;
506 bool shader_int64;
507 bool buffer_device_address;
508
509 bool add_rms_fusion;
510 uint32_t partials_binding_alignment;
511
512 bool integer_dot_product;
513 // 0: default, 1: force mmvq, -1: disable mmvq
514 int32_t mmvq_mode;
515
516 bool subgroup_size_control;
517 uint32_t subgroup_min_size;
518 uint32_t subgroup_max_size;
519 bool subgroup_require_full_support;
520
521 bool coopmat_support;
522 bool coopmat_acc_f32_support {};
523 bool coopmat_acc_f16_support {};
524 bool coopmat_bf16_support {};
525 bool coopmat_support_16x16x16_f16acc {};
526 bool coopmat_support_16x16x16_f32acc {};
527 bool coopmat1_fa_support {};
528 uint32_t coopmat_m;
529 uint32_t coopmat_n;
530 uint32_t coopmat_k;
531
532 bool coopmat_int_support;
533 uint32_t coopmat_int_m;
534 uint32_t coopmat_int_n;
535 uint32_t coopmat_int_k;
536
537 bool coopmat2;
538
539 bool pipeline_executable_properties_support {};
540
541 size_t idx;
542
543 bool mul_mat_l[GGML_TYPE_COUNT];
544 bool mul_mat_m[GGML_TYPE_COUNT];
545 bool mul_mat_s[GGML_TYPE_COUNT];
546 bool mul_mat_id_l[GGML_TYPE_COUNT];
547 bool mul_mat_id_m[GGML_TYPE_COUNT];
548 bool mul_mat_id_s[GGML_TYPE_COUNT];
549
550 vk::DescriptorSetLayout dsl;
551
552 vk_matmul_pipeline pipeline_matmul_f32 {};
553 vk_matmul_pipeline pipeline_matmul_f32_f16 {};
554 vk_matmul_pipeline pipeline_matmul_bf16 {};
555 vk_matmul_pipeline2 pipeline_matmul_f16;
556 vk_matmul_pipeline2 pipeline_matmul_f16_f32;
557
558 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
559 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
560 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
561
562 vk_matmul_pipeline pipeline_matmul_id_f32 {};
563 vk_matmul_pipeline pipeline_matmul_id_bf16 {};
564 vk_matmul_pipeline2 pipeline_matmul_id_f16;
565 vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
566
567 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
568 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
569
570 vk_pipeline pipeline_matmul_split_k_reduce;
571 vk_pipeline pipeline_quantize_q8_1;
572 vk_pipeline pipeline_quantize_q8_1_x4;
573
574 vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
575 vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
576 vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
577 vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
578
579 vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
580
581 vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
582 vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
583 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
584 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
585 vk_pipeline pipeline_acc_f32;
586
587 // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
588 vk_pipeline pipeline_add[2][2][2];
589 vk_pipeline pipeline_add_norepeat[2][2][2];
590 vk_pipeline pipeline_sub[2][2][2];
591 vk_pipeline pipeline_sub_norepeat[2][2][2];
592 vk_pipeline pipeline_mul[2][2][2];
593 vk_pipeline pipeline_mul_norepeat[2][2][2];
594 vk_pipeline pipeline_div[2][2][2];
595 vk_pipeline pipeline_div_norepeat[2][2][2];
596 vk_pipeline pipeline_add_rms[2][2][2];
597 vk_pipeline pipeline_add_rms_norepeat[2][2][2];
598
599 // indexed by num_additional_fused_ops == num_adds - 1
600 vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
601 vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
602
603 vk_pipeline pipeline_add_id_f32;
604
605 vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
606 vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32;
607 vk_pipeline pipeline_scale_f32;
608 vk_pipeline pipeline_sqr_f32;
609 vk_pipeline pipeline_sqrt_f32;
610 vk_pipeline pipeline_sin_f32;
611 vk_pipeline pipeline_cos_f32;
612 vk_pipeline pipeline_clamp_f32;
613 vk_pipeline pipeline_pad_f32;
614 vk_pipeline pipeline_roll_f32;
615 vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
616 vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
617 vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
618 vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
619 vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
620 vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
621 vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
622 vk_pipeline pipeline_norm_f32;
623 vk_pipeline pipeline_group_norm_f32;
624 vk_pipeline pipeline_rms_norm_f32;
625 vk_pipeline pipeline_rms_norm_mul_f32;
626 vk_pipeline pipeline_rms_norm_partials_f32;
627 vk_pipeline pipeline_rms_norm_mul_partials_f32;
628 vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
629 vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
630 vk_pipeline pipeline_rms_norm_back_f32;
631 vk_pipeline pipeline_l2_norm_f32;
632
633 // [src/dst 0=fp32,1=fp16]
634 vk_pipeline pipeline_exp[2];
635 vk_pipeline pipeline_gelu[2];
636 vk_pipeline pipeline_gelu_erf[2];
637 vk_pipeline pipeline_gelu_quick[2];
638 vk_pipeline pipeline_silu[2];
639 vk_pipeline pipeline_relu[2];
640 vk_pipeline pipeline_tanh[2];
641 vk_pipeline pipeline_sigmoid[2];
642 vk_pipeline pipeline_hardsigmoid[2];
643 vk_pipeline pipeline_hardswish[2];
644
645 vk_pipeline pipeline_geglu[2];
646 vk_pipeline pipeline_reglu[2];
647 vk_pipeline pipeline_swiglu[2];
648 vk_pipeline pipeline_swiglu_oai[2];
649 vk_pipeline pipeline_geglu_erf[2];
650 vk_pipeline pipeline_geglu_quick[2];
651
652 vk_pipeline pipeline_leaky_relu_f32;
653 vk_pipeline pipeline_silu_back_f32;
654 vk_pipeline pipeline_diag_mask_inf_f32;
655 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
656 vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
657 vk_pipeline pipeline_soft_max_back_f32;
658 vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
659 vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
660 vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
661 vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
662 vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
663 vk_pipeline pipeline_sum_rows_f32;
664 vk_pipeline pipeline_argmax_f32;
665 vk_pipeline pipeline_count_equal_i32;
666 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
667 vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
668 vk_pipeline pipeline_timestep_embedding_f32;
669 vk_pipeline pipeline_conv_transpose_1d_f32;
670 vk_pipeline pipeline_pool2d_f32;
671 vk_pipeline pipeline_rwkv_wkv6_f32;
672 vk_pipeline pipeline_rwkv_wkv7_f32;
673 vk_pipeline pipeline_ssm_scan_f32_d128;
674 vk_pipeline pipeline_ssm_scan_f32_d256;
675 vk_pipeline pipeline_ssm_conv_f32;
676 vk_pipeline pipeline_opt_step_adamw_f32;
677 vk_pipeline pipeline_opt_step_sgd_f32;
678 vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
679 vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
680 vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
681 vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
682 vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
683 vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
684
685 std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
686
687 vk_pipeline pipeline_flash_attn_split_k_reduce;
688
689 vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
690
691 std::vector<vk_pipeline_ref> all_pipelines;
692
693 std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
694
695 vk::Fence fence;
696 vk_buffer sync_staging;
697
698 ggml_backend_buffer_type buffer_type;
699
700 bool disable_fusion;
701 bool disable_host_visible_vidmem;
702 bool allow_sysmem_fallback;
703 bool disable_graph_optimize;
704
705#ifdef GGML_VULKAN_MEMORY_DEBUG
706 std::unique_ptr<vk_memory_logger> memory_logger;
707#endif
708
709 // for GGML_VK_PERF_LOGGER
710 std::unique_ptr<vk_perf_logger> perf_logger;
711 vk::QueryPool query_pool;
712 int32_t num_queries;
713
714 ~vk_device_struct() {
715 VK_LOG_DEBUG("destroy device " << name);
716
717 device.destroyFence(fence);
718
        ggml_vk_destroy_buffer(sync_staging);
720
721 compute_queue.cmd_pool.destroy(device);
722 transfer_queue.cmd_pool.destroy(device);
723
724 for (auto& pipeline : all_pipelines) {
725 if (pipeline.expired()) {
726 continue;
727 }
728
729 vk_pipeline pl = pipeline.lock();
            ggml_vk_destroy_pipeline(device, pl);
731 }
732 all_pipelines.clear();
733
        device.destroyDescriptorSetLayout(dsl);
735
736 device.destroy();
737 }
738};
739
740void vk_command_pool::init(vk_device& device, vk_queue *q_) {
741 cmd_buffer_idx = 0;
742 q = q_;
743
744 vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
    pool = device->device.createCommandPool(command_pool_create_info);
746}
747
748void vk_command_pool::destroy(vk::Device& device) {
    device.destroyCommandPool(pool);
750 pool = nullptr;
751 cmd_buffers.clear();
752}
753
754struct vk_buffer_struct {
755 vk::Buffer buffer = VK_NULL_HANDLE;
756 vk::DeviceMemory device_memory = VK_NULL_HANDLE;
757 vk::MemoryPropertyFlags memory_property_flags;
758 void * ptr;
759 size_t size = 0;
760 vk::DeviceAddress bda_addr {};
761
762 vk_device device;
763
764 ~vk_buffer_struct() {
765 if (size == 0) {
766 return;
767 }
768 VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")");
769
        device->device.freeMemory(device_memory);
771 device->device.destroyBuffer(buffer);
772 }
773};
774
775struct vk_subbuffer {
776 vk_buffer buffer;
777 uint64_t offset;
778 uint64_t size;
779
780 operator vk::DescriptorBufferInfo() const {
781 return { buffer->buffer, offset, size };
782 }
783};
784
785struct vk_semaphore {
786 vk::Semaphore s;
787 uint64_t value;
788};
789
790struct vk_submission {
791 vk::CommandBuffer buffer;
792 std::vector<vk_semaphore> wait_semaphores;
793 std::vector<vk_semaphore> signal_semaphores;
794};
795
796typedef std::vector<vk_submission> vk_sequence;
797
798struct vk_mat_mat_push_constants {
799 uint32_t M; uint32_t N; uint32_t K;
800 uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
801 uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
802 uint32_t k_split;
803 uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
804 uint32_t padded_N;
805};
806struct vk_mat_vec_push_constants {
807 uint32_t ncols;
808 uint32_t stride_a;
809 uint32_t stride_b;
810 uint32_t stride_d;
811 uint32_t batch_stride_a;
812 uint32_t batch_stride_b;
813 uint32_t batch_stride_d;
814 uint32_t enable_bias;
815 uint32_t ne02;
816 uint32_t ne12;
817 uint32_t broadcast2;
818 uint32_t broadcast3;
819};
820
821struct vk_mat_mat_id_push_constants {
822 uint32_t M; uint32_t N; uint32_t K;
823 uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
824 uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
825 uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
826 uint32_t padded_N;
827};
828struct vk_mat_vec_id_push_constants {
829 uint32_t ncols;
830 uint32_t stride_a;
831 uint32_t stride_b;
832 uint32_t stride_d;
833 uint32_t batch_stride_a;
834 uint32_t batch_stride_b;
835 uint32_t batch_stride_d;
836 uint32_t enable_bias;
837 uint32_t nei0;
838 uint32_t ne11;
839};
840
841struct vk_flash_attn_push_constants {
842 uint32_t N;
843 uint32_t KV;
844
845 uint32_t ne1;
846 uint32_t ne2;
847 uint32_t ne3;
848
849 uint32_t neq2;
850 uint32_t neq3;
851 uint32_t nek2;
852 uint32_t nek3;
853 uint32_t nev2;
854 uint32_t nev3;
855 uint32_t nem1;
856 uint32_t nem2;
857 uint32_t nem3;
858
859 uint32_t nb01;
860 uint32_t nb02;
861 uint32_t nb03;
862 uint32_t nb11;
863 uint32_t nb12;
864 uint32_t nb13;
865 uint32_t nb21;
866 uint32_t nb22;
867 uint32_t nb23;
868
869 float scale;
870 float max_bias;
871 float logit_softcap;
872
873 uint32_t mask_n_head_log2;
874 float m0;
875 float m1;
876
877 uint32_t gqa_ratio;
878 uint32_t split_kv;
879 uint32_t k_num;
880};
881static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
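// 128 bytes is the minimum maxPushConstantsSize guaranteed by the Vulkan spec, hence the limit.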
882
883struct vk_op_push_constants {
884 uint32_t KX;
885 uint32_t KY;
886 float param1;
887 float param2;
888};
889
890struct vk_op_glu_push_constants {
891 uint32_t N;
892 uint32_t ne00;
893 uint32_t ne20;
894 uint32_t mode; // 0: default, 1: swapped, 2: split
895 float alpha; // for swiglu_oai
896 float limit;
897};
898
899struct vk_op_unary_push_constants {
900 uint32_t ne;
901 uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
902 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
903 uint32_t misalign_offsets;
904 float param1; float param2;
905 uint32_t ne0_012mp; uint32_t ne0_012L;
906 uint32_t ne0_01mp; uint32_t ne0_01L;
907 uint32_t ne0_0mp; uint32_t ne0_0L;
908 uint32_t ne1_012mp; uint32_t ne1_012L;
909 uint32_t ne1_01mp; uint32_t ne1_01L;
910 uint32_t ne1_0mp; uint32_t ne1_0L;
911};
912static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
913
914static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
915 GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
    ne = ne != 0 ? ne : ggml_nelements(dst);
917 GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
918
919 vk_op_unary_push_constants p{};
920 p.ne = (uint32_t)ne;
921
    size_t src0_tsize = ggml_type_size(src0->type);
923 p.ne00 = (uint32_t)src0->ne[0];
924 p.ne01 = (uint32_t)src0->ne[1];
925 p.ne02 = (uint32_t)src0->ne[2];
926 p.ne03 = (uint32_t)src0->ne[3];
927 p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
928 p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
929 p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
930 p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
931
    size_t dst_tsize = ggml_type_size(dst->type);
933 p.ne10 = (uint32_t)dst->ne[0];
934 p.ne11 = (uint32_t)dst->ne[1];
935 p.ne12 = (uint32_t)dst->ne[2];
936 p.ne13 = (uint32_t)dst->ne[3];
937 p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
938 p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
939 p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
940 p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
941
942 return p; // offsets are initialized later in ggml_vk_op
943}
944
945struct vk_op_pad_push_constants {
946 uint32_t ne;
947 uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
948 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
949 uint32_t misalign_offsets;
950
951 uint32_t lp0; uint32_t rp0;
952 uint32_t lp1; uint32_t rp1;
953 uint32_t lp2; uint32_t rp2;
954 uint32_t lp3; uint32_t rp3;
955};
956
957static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
    int64_t ne = ggml_nelements(dst);
959 GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
960
961 vk_op_pad_push_constants p{};
962 p.ne = (uint32_t)ne;
963
    size_t src0_tsize = ggml_type_size(src0->type);
965 p.ne00 = (uint32_t)src0->ne[0];
966 p.ne01 = (uint32_t)src0->ne[1];
967 p.ne02 = (uint32_t)src0->ne[2];
968 p.ne03 = (uint32_t)src0->ne[3];
969 p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
970 p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
971 p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
972 p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
973
    size_t dst_tsize = ggml_type_size(dst->type);
975 p.ne10 = (uint32_t)dst->ne[0];
976 p.ne11 = (uint32_t)dst->ne[1];
977 p.ne12 = (uint32_t)dst->ne[2];
978 p.ne13 = (uint32_t)dst->ne[3];
979 p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
980 p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
981 p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
982 p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
983
984 p.lp0 = dst->op_params[0];
985 p.rp0 = dst->op_params[1];
986 p.lp1 = dst->op_params[2];
987 p.rp1 = dst->op_params[3];
988 p.lp2 = dst->op_params[4];
989 p.rp2 = dst->op_params[5];
990 p.lp3 = dst->op_params[6];
991 p.rp3 = dst->op_params[7];
992
993 return p; // fastdiv values and offsets are initialized later in ggml_vk_op
994}
995
996// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
997// Precompute mp (m' in the paper) and L such that division
998// can be computed using a multiply (high 32b of 64b result)
999// and a shift:
1000//
1001// n/d = (mulhi(n, mp) + n) >> L;
1002static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
1003{
1004 // compute L = ceil(log2(d));
1005 L = 0;
1006 while (L < 32 && (uint32_t{1} << L) < d) {
1007 L++;
1008 }
1009
1010 mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
1011}
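// Example: d = 3 gives L = 2 and mp = (2^32 * (4 - 3)) / 3 + 1 = 1431655766, so for n = 10:
// (mulhi(10, mp) + 10) >> 2 = (3 + 10) >> 2 = 3 == 10/3.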
1012
1013template <typename T> void init_pushconst_fastdiv(T &p) {
1014 GGML_UNUSED(p);
1015 static_assert(!std::is_const<T>::value, "unexpected type");
1016}
1017
1018template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
1019 // Compute magic values to divide by these six numbers.
    init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L);
    init_fastdiv_values(p.ne01*p.ne00,        p.ne0_01mp,  p.ne0_01L);
    init_fastdiv_values(p.ne00,               p.ne0_0mp,   p.ne0_0L);
    init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L);
    init_fastdiv_values(p.ne11*p.ne10,        p.ne1_01mp,  p.ne1_01L);
    init_fastdiv_values(p.ne10,               p.ne1_0mp,   p.ne1_0L);
1026}
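// The corresponding shaders use these precomputed (mp, L) pairs to turn the flat-index -> 4D
// coordinate divisions into a multiply-high and shift, as described above init_fastdiv_values.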
1027
1028struct vk_op_binary_push_constants {
1029 uint32_t ne;
1030 uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
1031 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
1032 uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
1033 uint32_t misalign_offsets;
1034 float param1; float param2; int32_t param3;
1035};
1036
1037struct vk_op_multi_add_push_constants {
1038 // shape for dst
1039 uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
1040
1041 // strides for srcs+dst
1042 uint32_t nb[MAX_PARAMETER_COUNT][4];
1043
1044 uint32_t rms_partials;
1045};
1046// update multi_add.comp if this changes
1047static_assert(MAX_PARAMETER_COUNT == 12);
1048static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
1049
1050struct vk_op_topk_moe_push_constants {
1051 uint32_t n_rows;
1052 uint32_t n_expert_used;
1053 float clamp_min;
1054 float clamp_max;
1055};
1056
1057struct vk_op_add_id_push_constants {
1058 uint32_t ne0;
1059 uint32_t ne1;
1060 uint32_t s01;
1061 uint32_t s02;
1062 uint32_t s11;
1063 uint32_t s21;
1064};
1065
1066struct vk_op_diag_mask_push_constants {
1067 uint32_t ncols;
1068 uint32_t rows_per_channel;
1069 int32_t n_past;
1070};
1071
1072struct vk_op_rope_push_constants {
1073 uint32_t rope_mode;
1074 uint32_t ncols;
1075 uint32_t n_dims;
1076 float freq_scale;
1077 uint32_t p_delta_rows;
1078 float freq_base;
1079 float ext_factor;
1080 float attn_factor;
1081 float corr_dims[2];
1082 float theta_scale;
1083 uint32_t has_ff;
1084 uint32_t ne02;
1085 uint32_t s1;
1086 uint32_t s2;
1087 int32_t sections[4];
1088 uint32_t is_imrope;
1089 uint32_t is_back;
1090 uint32_t set_rows_stride;
1091};
1092
1093// For fused rms_norm+mul+rope(+view+set_rows)
1094struct vk_op_rms_norm_mul_rope_push_constants {
1095 vk_op_binary_push_constants bin;
1096 vk_op_rope_push_constants rope;
1097};
1098
1099struct vk_op_soft_max_push_constants {
1100 uint32_t KX;
1101 uint32_t KY;
1102 uint32_t ne00;
1103 uint32_t ne01;
1104 uint32_t ne02;
1105 uint32_t ne12;
1106 uint32_t ne13;
1107 uint32_t nb11;
1108 uint32_t nb12;
1109 uint32_t nb13;
1110 float scale;
1111 float max_bias;
1112 float m0;
1113 float m1;
1114 uint32_t n_head_log2;
1115 uint32_t nrows_x;
1116 uint32_t has_sinks;
1117};
1118
1119struct vk_op_argsort_push_constants {
1120 uint32_t ncols;
1121 uint32_t nrows;
1122 int32_t order;
1123};
1124
1125struct vk_op_im2col_push_constants {
1126 uint64_t dst_addr;
1127 uint32_t batch_offset; uint32_t offset_delta;
1128 uint32_t IC;
1129 uint32_t IW; uint32_t IH;
1130 uint32_t OW; uint32_t OH;
1131 uint32_t KW; uint32_t KH;
1132 uint32_t pelements;
1133 uint32_t CHW;
1134 int32_t s0; int32_t s1;
1135 int32_t p0; int32_t p1;
1136 int32_t d0; int32_t d1;
1137};
1138
1139struct vk_op_im2col_3d_push_constants {
1140 uint64_t dst_addr;
1141 uint32_t nb10;
1142 uint32_t nb11;
1143 uint32_t nb12;
1144 uint32_t nb13;
1145 uint32_t s0;
1146 uint32_t s1;
1147 uint32_t s2;
1148 uint32_t p0;
1149 uint32_t p1;
1150 uint32_t p2;
1151 uint32_t d0;
1152 uint32_t d1;
1153 uint32_t d2;
1154 uint32_t IW;
1155 uint32_t IH;
1156 uint32_t ID;
1157 uint32_t IC;
1158 uint32_t KW;
1159 uint32_t OH;
1160 uint32_t KD_KH_KW;
1161 uint32_t KH_KW;
1162 uint32_t IC_KD_KH_KW;
1163 uint32_t N_OD_OH;
1164 uint32_t OD_OH;
1165 uint32_t OD_OH_OW_IC_KD_KH_KW;
1166 uint32_t OH_OW_IC_KD_KH_KW;
1167 uint32_t OW_IC_KD_KH_KW;
1168 uint32_t misalign_offsets;
1169};
1170
1171struct vk_op_timestep_embedding_push_constants {
1172 uint32_t nb1;
1173 uint32_t dim;
1174 uint32_t max_period;
1175};
1176
1177struct vk_op_conv_transpose_1d_push_constants {
1178 uint32_t Cout;
1179 uint32_t Cin;
1180 uint32_t K;
1181 uint32_t L;
1182 uint32_t KL;
1183
1184 uint32_t nb01;
1185 uint32_t nb02;
1186 uint32_t nb11;
1187 uint32_t nb1;
1188
1189 int32_t s0;
1190};
1191
1192struct vk_op_pool2d_push_constants {
1193 uint32_t IW; uint32_t IH;
1194 uint32_t OW; uint32_t OH;
1195 uint32_t OC;
1196 uint32_t pelements;
1197 uint32_t op;
1198 int32_t k0; int32_t k1;
1199 int32_t s0; int32_t s1;
1200 int32_t p0; int32_t p1;
1201};
1202
1203struct vk_op_rwkv_wkv6_push_constants {
1204 uint32_t B;
1205 uint32_t T;
1206 uint32_t C;
1207 uint32_t H;
1208};
1209
1210struct vk_op_rwkv_wkv7_push_constants {
1211 uint32_t B;
1212 uint32_t T;
1213 uint32_t C;
1214 uint32_t H;
1215};
1216struct vk_op_ssm_scan_push_constants {
1217 uint32_t nb02, nb03, nb12, nb13;
1218 uint32_t nb21, nb22, nb31;
1219 uint32_t nb42, nb43, nb52, nb53;
1220 uint32_t s_off;
1221 uint32_t n_head, d_head, n_group, n_tok;
1222};
1223struct vk_op_ssm_conv_push_constants {
1224 uint32_t nb01, nb02;
1225 uint32_t nb11;
1226 uint32_t dst_nb0, dst_nb1, dst_nb2;
1227 uint32_t nc, ncs, nr, n_t, n_s;
1228};
1229
1230struct vk_op_conv2d_push_constants {
1231 uint32_t Cout;
1232 uint32_t Cin;
1233 uint32_t N;
1234
1235 uint32_t KW;
1236 uint32_t KH;
1237 uint32_t W;
1238 uint32_t H;
1239 uint32_t OW;
1240 uint32_t OH;
1241
1242 uint32_t s0;
1243 uint32_t s1;
1244 uint32_t p0;
1245 uint32_t p1;
1246 uint32_t d0;
1247 uint32_t d1;
1248
1249 uint32_t nb01;
1250 uint32_t nb02;
1251 uint32_t nb03;
1252
1253 uint32_t nb11;
1254 uint32_t nb12;
1255 uint32_t nb13;
1256
1257 uint32_t nb1;
1258 uint32_t nb2;
1259 uint32_t nb3;
1260
1261 // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
1262 uint32_t KWmp; uint32_t KWL;
1263 uint32_t KWKHmp; uint32_t KWKHL;
1264 uint32_t OWmp; uint32_t OWL;
1265 uint32_t OWOHmp; uint32_t OWOHL;
1266};
1267
1268template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
1269 // Compute magic values to divide by KW, KW*KH, OW, OW*OH
    init_fastdiv_values(p.KW,      p.KWmp,   p.KWL);
    init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
    init_fastdiv_values(p.OW,      p.OWmp,   p.OWL);
    init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
1274}
1275
1276struct vk_op_conv_transpose_2d_push_constants {
1277 uint32_t Cout;
1278 uint32_t Cin;
1279 uint32_t N;
1280
1281 uint32_t KW;
1282 uint32_t KH;
1283 uint32_t W;
1284 uint32_t H;
1285 uint32_t OW;
1286 uint32_t OH;
1287
1288 uint32_t s0;
1289 uint32_t s1;
1290 uint32_t p0;
1291 uint32_t p1;
1292 uint32_t d0;
1293 uint32_t d1;
1294
1295 uint32_t nb01;
1296 uint32_t nb02;
1297 uint32_t nb03;
1298
1299 uint32_t nb11;
1300 uint32_t nb12;
1301 uint32_t nb13;
1302
1303 uint32_t nb1;
1304 uint32_t nb2;
1305 uint32_t nb3;
1306
1307 // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
1308 uint32_t KWmp; uint32_t KWL;
1309 uint32_t KWKHmp; uint32_t KWKHL;
1310 uint32_t OWmp; uint32_t OWL;
1311 uint32_t OWOHmp; uint32_t OWOHL;
1312 uint32_t s0mp; uint32_t s0L;
1313 uint32_t s1mp; uint32_t s1L;
1314};
1315
1316template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
1317 // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
    init_fastdiv_values(p.KW,      p.KWmp,   p.KWL);
    init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
    init_fastdiv_values(p.OW,      p.OWmp,   p.OWL);
    init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
    init_fastdiv_values(p.s0,      p.s0mp,   p.s0L);
    init_fastdiv_values(p.s1,      p.s1mp,   p.s1L);
1324}
1325
1326struct vk_op_conv2d_dw_push_constants {
1327 uint32_t ne;
1328 uint32_t batches;
1329 uint32_t channels;
1330 uint32_t dst_w;
1331 uint32_t dst_h;
1332 uint32_t src_w;
1333 uint32_t src_h;
1334 uint32_t knl_w;
1335 uint32_t knl_h;
1336 int32_t stride_x;
1337 int32_t stride_y;
1338 int32_t pad_x;
1339 int32_t pad_y;
1340 int32_t dilation_x;
1341 int32_t dilation_y;
1342};
1343
1344struct vk_op_upscale_push_constants {
1345 uint32_t ne; uint32_t a_offset; uint32_t d_offset;
1346 uint32_t ne00; uint32_t ne01;
1347 uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
1348 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
1349 float sf0; float sf1; float sf2; float sf3;
1350 float pixel_offset;
1351};
1352
1353struct vk_op_sum_rows_push_constants
1354{
1355 uint32_t n_cols;
1356 uint32_t ne01, ne02;
1357 uint32_t nb01, nb02, nb03;
1358 uint32_t nb11, nb12, nb13;
1359 float weight;
1360 uint32_t misalign_offsets;
1361 uint32_t ne0_12mp, ne0_12L;
1362 uint32_t ne0_1mp, ne0_1L;
1363};
1364
1365static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
    uint32_t type_size = (uint32_t)ggml_type_size(src->type);
1367 vk_op_sum_rows_push_constants p = {};
1368 p.n_cols = (uint32_t)n_cols;
1369 p.ne01 = (uint32_t)src->ne[1];
1370 p.ne02 = (uint32_t)src->ne[2];
1371 p.nb01 = (uint32_t)src->nb[1] / type_size;
1372 p.nb02 = (uint32_t)src->nb[2] / type_size;
1373 p.nb03 = (uint32_t)src->nb[3] / type_size;
1374 p.nb11 = (uint32_t)dst->nb[1] / type_size;
1375 p.nb12 = (uint32_t)dst->nb[2] / type_size;
1376 p.nb13 = (uint32_t)dst->nb[3] / type_size;
1377 p.weight = 1.0f;
1378 return p;
1379}
1380
1381template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
    init_fastdiv_values(p.ne01,        p.ne0_1mp,  p.ne0_1L);
1384}
1385
1386// Allow pre-recording command buffers
1387struct vk_staging_memcpy {
1388 vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
1389
1390 void * dst;
1391 const void * src;
1392 size_t n;
1393};
1394
1395struct vk_staging_memset {
1396 vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}
1397
1398 void * dst;
1399 uint32_t val;
1400 size_t n;
1401};
1402
1403struct vk_context_struct {
1404 vk_submission * s;
1405 std::vector<vk_sequence> seqs;
1406
1407 int exit_tensor_idx;
1408
1409 std::vector<vk_staging_memcpy> in_memcpys;
1410 std::vector<vk_staging_memcpy> out_memcpys;
1411 std::vector<vk_staging_memset> memsets;
1412
1413 vk_command_pool * p {};
1414};
1415typedef std::shared_ptr<vk_context_struct> vk_context;
1416typedef std::weak_ptr<vk_context_struct> vk_context_ref;
1417
1418struct ggml_vk_garbage_collector {
1419 std::vector<vk_semaphore> tl_semaphores;
1420 std::vector<vk_semaphore> semaphores;
1421 std::vector<vk::Event> events;
1422 std::vector<vk_context> contexts;
1423};
1424
1425static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
1426static void ggml_vk_load_shaders(vk_device& device);
1427static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
1428
1429#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG)
1430#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl
1431
1432static std::string format_size(size_t size) {
1433 const size_t kib = 1024;
1434 const size_t mib = kib * 1024;
1435 const size_t gib = mib * 1024;
1436
1437 std::ostringstream oss;
1438 oss << std::fixed << std::setprecision(2);
1439
1440 if (size >= gib) {
1441 oss << static_cast<double>(size) / gib << " GiB";
1442 } else if (size >= mib) {
1443 oss << static_cast<double>(size) / mib << " MiB";
1444 } else if (size >= kib) {
1445 oss << static_cast<double>(size) / kib << " KiB";
1446 } else {
1447 oss << size << " B";
1448 }
1449
1450 return oss.str();
1451}
1452
1453class vk_memory_logger {
1454public:
1455 vk_memory_logger(): total_device(0), total_host(0) {}
1456 void log_allocation(vk_buffer_ref buf_ref, size_t size);
1457 void log_deallocation(vk_buffer_ref buf_ref);
1458
1459private:
1460 std::map<vk::Buffer, size_t> allocations; // Track allocations
1461 size_t total_device;
1462 size_t total_host;
1463};
1464#else
1465#define VK_LOG_MEMORY(msg) ((void) 0)
1466#endif // GGML_VULKAN_MEMORY_DEBUG
1467
1468class vk_perf_logger {
1469 public:
1470 void print_timings() {
1471 if (timings.empty()) {
1472 return;
1473 }
1474 uint64_t total_all_op_times = 0;
1475 std::cerr << "----------------\nVulkan Timings:" << std::endl;
1476 for (const auto & t : timings) {
1477 uint64_t total_op_times = 0;
1478 for (const auto & time : t.second) {
1479 total_op_times += time;
1480 }
1481 std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
1482 << " us";
1483
1484 // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
            auto it = flops.find(t.first);
1486 if (it != flops.end() && (it->second).size() == t.second.size()) {
1487 uint64_t total_op_flops = 0;
1488 for (const auto & elem : it->second) {
1489 total_op_flops += elem;
1490 }
1491 std::cerr << " ("
1492 << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
1493 (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
1494 << " GFLOPS/s)";
1495 }
1496
1497 total_all_op_times += total_op_times;
1498
1499 std::cerr << std::endl;
1500 }
1501
1502 if (timings.size() > 0) {
1503 std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
1504 }
1505
1506 timings.clear();
1507 flops.clear();
1508 }
1509
1510 void log_timing(const ggml_tensor * node, uint64_t time) {
1511 if (node->op == GGML_OP_UNARY) {
            timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
1513 return;
1514 }
1515 if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
1516 const uint64_t m = node->src[0]->ne[1];
1517 const uint64_t n = node->ne[1];
1518 const uint64_t k = node->src[1]->ne[0];
1519 const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
            std::string name = ggml_op_name(node->op);
1521 if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
1522 (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
1523 name += "_VEC";
1524 }
1525 name += " ";
            name += ggml_type_name(node->src[0]->type);
            name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
            if (batch > 1) {
                name += " batch=" + std::to_string(batch);
            }
            timings[name].push_back(time);
            flops[name].push_back(m * n * (k + (k - 1)) * batch);
1533 return;
1534 }
1535 if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
            std::string name = ggml_op_name(node->op);
1537 ggml_tensor * knl = node->src[0];
1538 uint64_t OW = node->ne[0];
1539 uint64_t OH = node->ne[1];
1540 uint64_t N = node->ne[3];
1541 uint64_t Cout = node->ne[2];
1542 uint64_t KW = knl->ne[0];
1543 uint64_t KH = knl->ne[1];
1544 uint64_t Cin = node->src[1]->ne[2];
1545 // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
1546 uint64_t size_M = Cout;
1547 uint64_t size_K = Cin * KW * KH;
1548 uint64_t size_N = N * OW * OH;
1549 uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
1550 name += " M=Cout=" + std::to_string(val: size_M) + ", K=Cin*KW*KH=" + std::to_string(val: size_K) +
1551 ", N=N*OW*OH=" + std::to_string(val: size_N);
1552 flops[name].push_back(x: n_flops);
1553 timings[name].push_back(x: time);
1554 return;
1555 }
1556 if (node->op == GGML_OP_RMS_NORM) {
            std::string name = ggml_op_name(node->op);
            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
            timings[name].push_back(time);
1560 return;
1561 }
        timings[ggml_op_name(node->op)].push_back(time);
1563 }
1564 private:
1565 std::map<std::string, std::vector<uint64_t>> timings;
1566 std::map<std::string, std::vector<uint64_t>> flops;
1567};
1568
1569struct ggml_backend_vk_context {
1570 std::string name;
1571
1572 vk_device device;
1573
1574 size_t semaphore_idx, event_idx;
1575 ggml_vk_garbage_collector gc;
1576 size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
1577 vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
1578 vk::Fence fence, almost_ready_fence;
1579 bool almost_ready_fence_pending {};
1580 // Set before op_add and unset after op_rms_norm to indicate that the add should
1581 // write partial sums to accumulate the square of the vector components
1582 bool do_add_rms_partials_offset_calculation;
1583 bool do_add_rms_partials;
1584
1585 uint64_t last_total_mul_mat_bytes {};
1586
1587 // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
1588 vk_pipeline_struct * prealloc_y_last_pipeline_used {};
1589 const ggml_tensor * prealloc_y_last_tensor_used {};
1590
1591 // Track which nodes have been used since the last sync, and whether they were written to
1592 std::vector<const ggml_tensor *> unsynced_nodes_written;
1593 std::vector<const ggml_tensor *> unsynced_nodes_read;
1594 // Track which prealloc buffers have pending reads that need to be synchronized.
1595 // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
1596 // and set to true after the buffer contents are consumed.
1597 bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
1598
1599 vk_context_ref compute_ctx;
1600 vk_context_ref transfer_ctx;
1601
1602 std::vector<vk_context_ref> tensor_ctxs;
1603
1604 std::vector<vk::DescriptorPool> descriptor_pools;
1605 std::vector<vk::DescriptorSet> descriptor_sets;
1606 uint32_t descriptor_set_idx {};
1607 uint32_t pipeline_descriptor_set_requirements {};
1608
1609 vk_command_pool compute_cmd_pool;
1610 vk_command_pool transfer_cmd_pool;
1611
1612 // number of additional consecutive nodes that are being fused with the
1613 // node currently being processed
1614 int num_additional_fused_ops {};
1615 // Bitmask of which fused ops need to write an intermediate value to memory.
1616 // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
1617 // If there's no fusion, bit 0 is still set.
1618 int fused_ops_write_mask {};
1619};
1620
1621static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
1622
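// Tensors in Vulkan buffers store their byte offset within the buffer in tensor->data, biased
// by vk_ptr_base so that an offset of zero still yields a non-NULL pointer.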
1623static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
1624 if (tensor->view_src) {
1625 return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
1626 }
1627 return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
1628}
1629
1630struct ggml_backend_vk_buffer_context {
1631 vk_device_ref device;
1632 vk_buffer dev_buffer;
1633 std::string name;
1634
1635 ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
1636 device(device),
1637 dev_buffer(dev_buffer),
1638 name(name) {
1639 }
1640
1641 ~ggml_backend_vk_buffer_context() {
        ggml_vk_destroy_buffer(dev_buffer);
1643 }
1644};
1645
1646#ifdef GGML_VULKAN_MEMORY_DEBUG
1647static std::mutex log_mutex;
1648
1649void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
1650 std::lock_guard<std::mutex> guard(log_mutex);
1651 vk_buffer buf = buf_ref.lock();
1652 const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
1653 const std::string type = device ? "device" : "host";
1654 allocations[buf->buffer] = size;
1655 total_device += device ? size : 0;
1656 total_host += device ? 0 : size;
1657 VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
1658}
1659
1660void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
1661 if (buf_ref.expired() || buf_ref.lock()->size == 0) {
1662 return;
1663 }
1664
1665 std::lock_guard<std::mutex> guard(log_mutex);
1666 vk_buffer buf = buf_ref.lock();
1667 const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
1668 std::string type = device ? "device" : "host";
    auto it = allocations.find(buf->buffer);
    if (it != allocations.end()) {
        total_device -= device ? it->second : 0;
        total_host -= device ? 0 : it->second;
        VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
        allocations.erase(it);
1675 } else {
1676 VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
1677 }
1678}
1679#endif // GGML_VULKAN_MEMORY_DEBUG
1680
1681struct vk_instance_t {
1682 vk::Instance instance;
1683
1684 bool debug_utils_support = false; // VK_EXT_debug_utils enabled
1685 PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
1686 PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
1687 PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
1688 PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
1689 PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
1690 PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
1691
1692 std::vector<size_t> device_indices;
1693 std::vector<bool> device_supports_membudget;
1694 vk_device devices[GGML_VK_MAX_DEVICES];
1695};
1696
1697static bool vk_instance_initialized = false;
1698static vk_instance_t vk_instance;
1699
1700static bool vk_perf_logger_enabled = false;
1701
1702#ifdef GGML_VULKAN_CHECK_RESULTS
1703static size_t vk_skip_checks;
1704static size_t vk_output_tensor;
1705
1706static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
1707static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1708static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1709#endif
1710
1711typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
1712
1713static void ggml_backend_vk_free(ggml_backend_t backend);
1714
1715static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
    const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
                                        VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
1718 return range;
1719}
1720
1721// Wait for ctx->fence to be signaled.
1722static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
1723 // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
1724 // during this wait.
1725 if (ctx->almost_ready_fence_pending) {
1726 VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
        ctx->device->device.resetFences({ ctx->almost_ready_fence });
1728 ctx->almost_ready_fence_pending = false;
1729 }
1730
1731 // Spin (w/pause) waiting for the graph to finish executing.
1732 vk::Result result;
    while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
        if (result != vk::Result::eNotReady) {
            fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
            exit(1);
1737 }
1738 for (uint32_t i = 0; i < 100; ++i) {
1739 YIELD();
1740 YIELD();
1741 YIELD();
1742 YIELD();
1743 YIELD();
1744 YIELD();
1745 YIELD();
1746 YIELD();
1747 YIELD();
1748 YIELD();
1749 }
1750 }
    ctx->device->device.resetFences({ ctx->fence });
1752}
1753
1754// variables to track number of compiles in progress
1755static uint32_t compile_count = 0;
1756static std::mutex compile_count_mutex;
1757static std::condition_variable compile_count_cond;
1758
1759static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
1760 uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
1761 bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
1762 VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count <<
1763 ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
1764 disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
1765 GGML_ASSERT(parameter_count > 0);
1766 GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT);
1767 GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
1768
1769 vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
    pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
1771
1772 vk::PushConstantRange pcr(
1773 vk::ShaderStageFlagBits::eCompute,
1774 0,
1775 pipeline->push_constant_size
1776 );
1777
1778 vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr);
    pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
1780
1781 std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
1782
1783 for (size_t i = 0; i < specialization_constants.size(); i++) {
1784 specialization_entries[i].constantID = i;
1785 specialization_entries[i].offset = i * sizeof(uint32_t);
1786 specialization_entries[i].size = sizeof(uint32_t);
1787 }
1788
1789 vk::SpecializationInfo specialization_info(
1790 specialization_entries.size(),
1791 specialization_entries.data(),
1792 specialization_constants.size() * sizeof(uint32_t),
1793 specialization_constants.data()
1794 );
1795
1796 vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
1797
1798 if (device->subgroup_require_full_support && require_full_subgroups) {
1799 pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
1800 }
1801
1802 vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
1803 pipeline_shader_stage_create_flags,
1804 vk::ShaderStageFlagBits::eCompute,
1805 pipeline->shader_module,
1806 entrypoint.c_str(),
1807 &specialization_info);
1808
1809 vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
1810 pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
1811 if (device->subgroup_size_control && required_subgroup_size > 0) {
1812 GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
1813 pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
1814 }
1815
1816 vk::ComputePipelineCreateInfo compute_pipeline_create_info(
1817 device->pipeline_executable_properties_support ?
1818 vk::PipelineCreateFlagBits::eCaptureStatisticsKHR :
1819 vk::PipelineCreateFlags{},
1820 pipeline_shader_create_info,
1821 pipeline->layout);
1822
1823 vk::PipelineRobustnessCreateInfoEXT rci;
1824
1825 if (device->pipeline_robustness && disable_robustness) {
1826 rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
1827 rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
1828 compute_pipeline_create_info.setPNext(&rci);
1829 }
1830
1831 try {
        pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
1833 } catch (const vk::SystemError& e) {
1834 std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl;
1835 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
1836 throw e;
1837 }
1838 pipeline->compiled = true;
1839
1840 if (vk_instance.debug_utils_support) {
1841 vk::DebugUtilsObjectNameInfoEXT duoni;
1842 duoni.objectType = vk::ObjectType::ePipeline;
1843 duoni.pObjectName = pipeline->name.c_str();
1844 duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast<VkPipeline>(pipeline->pipeline));
1845 vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
1846 }
1847
1848 if (device->pipeline_executable_properties_support) {
1849 vk::PipelineExecutableInfoKHR executableInfo;
1850 executableInfo.pipeline = pipeline->pipeline;
1851
1852 auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
1853 for (auto & s : statistics) {
1854 // "Register Count" is reported by NVIDIA drivers.
            if (strcmp(s.name, "Register Count") == 0) {
1856 VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
1857 pipeline->register_count = (uint32_t)s.value.u64;
1858 }
1859 }
1860 }
1861
    device->all_pipelines.push_back(pipeline);
1863
1864 {
1865 std::lock_guard<std::mutex> guard(compile_count_mutex);
1866 assert(compile_count > 0);
1867 compile_count--;
1868 }
1869 compile_count_cond.notify_all();
1870}
1871
1872static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
1873 VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
    device.destroyPipelineLayout(pipeline->layout);

    device.destroyShaderModule(pipeline->shader_module);

    device.destroyPipeline(pipeline->pipeline);
1879}
1880
1881static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {
1882 VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
1883 ctx->pipeline_descriptor_set_requirements += n;
1884 if (!pipeline->compiled) {
1885 pipeline->needed = true;
        ggml_vk_load_shaders(ctx->device);
1887 }
1888 ggml_pipeline_allocate_descriptor_sets(ctx);
1889}
1890
1891static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {
1892
1893 if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {
1894 // Enough descriptors are available
1895 return;
1896 }
1897
1898 vk_device& device = ctx->device;
1899
1900 // Grow by 50% to avoid frequent allocations
    uint32_t needed = std::max(3 * ctx->descriptor_sets.size() / 2, size_t{ctx->pipeline_descriptor_set_requirements});
1902 uint32_t to_alloc = needed - ctx->descriptor_sets.size();
1903 uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1904 uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1905
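    // Worked example (hypothetical numbers): with 300 sets already allocated and
    // 500 required, needed = max(450, 500) = 500 and to_alloc = 200. With
    // VK_DEVICE_DESCRIPTOR_POOL_SIZE == 256, pool_remaining = 256 - 300 % 256 = 212
    // and pool_idx = 1, so all 200 new sets come from the existing second pool and
    // no additional descriptor pool is created.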
1906 while (to_alloc > 0) {
        const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
1908 to_alloc -= alloc_count;
1909 pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1910
1911 if (pool_idx >= ctx->descriptor_pools.size()) {
1912 vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
1913 vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
            ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
1915 }
1916
1917 std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
1918 for (uint32_t i = 0; i < alloc_count; i++) {
1919 layouts[i] = device->dsl;
1920 }
1921 vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
        std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
        ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
1924
1925 pool_idx++;
1926 }
1927}
1928
1929static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
1930 VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
1931
1932 if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
1933 // Reuse command buffer
1934 return p.cmd_buffers[p.cmd_buffer_idx++];
1935 }
1936
1937 vk::CommandBufferAllocateInfo command_buffer_alloc_info(
1938 p.pool,
1939 vk::CommandBufferLevel::ePrimary,
1940 1);
    const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
    auto buf = cmd_buffers.front();

    p.cmd_buffers.push_back(buf);
1945 p.cmd_buffer_idx++;
1946
1947 return buf;
1948}
1949
1950static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
1951 if (ctx->seqs.empty()) {
1952 if (fence) {
1953 std::lock_guard<std::mutex> guard(queue_mutex);
            ctx->p->q->queue.submit({}, fence);
1955 }
1956 return;
1957 }
1958 VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
1959
1960 std::vector<std::vector<uint64_t>> tl_wait_vals;
1961 std::vector<std::vector<uint64_t>> tl_signal_vals;
1962 std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
1963 std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
1964 std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
1965 std::vector<vk::SubmitInfo> submit_infos;
1966 int idx = -1;
1967 std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
1968
1969 size_t reserve = 0;
1970
1971 for (const auto& sequence : ctx->seqs) {
1972 reserve += sequence.size();
1973 }
1974
1975 // Pre-reserve vectors to prevent reallocation, which invalidates pointers
    tl_wait_semaphores.reserve(reserve);
    tl_wait_vals.reserve(reserve);
    tl_signal_semaphores.reserve(reserve);
    tl_signal_vals.reserve(reserve);
    tl_submit_infos.reserve(reserve);
    submit_infos.reserve(reserve);
    stage_flags.reserve(reserve);
1983
1984 for (const auto& sequence : ctx->seqs) {
1985 for (const auto& submission : sequence) {
            stage_flags.push_back({});
            idx++;
            tl_wait_vals.push_back({});
            tl_wait_semaphores.push_back({});
            tl_signal_vals.push_back({});
            tl_signal_semaphores.push_back({});
            for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
                stage_flags[idx].push_back(ctx->p->q->stage_flags);
                tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
                tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
            }
            for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
                tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
                tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
            }
            tl_submit_infos.push_back({
2002 (uint32_t) submission.wait_semaphores.size(),
2003 tl_wait_vals[idx].data(),
2004 (uint32_t) submission.signal_semaphores.size(),
2005 tl_signal_vals[idx].data(),
2006 });
2007 tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
2008 tl_submit_infos[idx].pNext = nullptr;
2009 vk::SubmitInfo si{
2010 (uint32_t) submission.wait_semaphores.size(),
2011 tl_wait_semaphores[idx].data(),
2012 stage_flags[idx].data(),
2013 1,
2014 &submission.buffer,
2015 (uint32_t) submission.signal_semaphores.size(),
2016 tl_signal_semaphores[idx].data(),
2017 };
2018 si.setPNext(&tl_submit_infos[idx]);
            submit_infos.push_back(si);
2020 }
2021 }
2022
2023 std::lock_guard<std::mutex> guard(queue_mutex);
    ctx->p->q->queue.submit(submit_infos, fence);
2025
2026 ctx->seqs.clear();
2027}
2028
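// Hypothetical example: on a device exposing the queue families
// {graphics|compute|transfer, compute|transfer, transfer-only}, a call with
// required = eTransfer and avoid = eCompute | eGraphics selects the dedicated
// transfer-only family; a device without one falls through the weaker criteria
// below and may end up reusing the compute family.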
2029static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
2030 VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
2031 const uint32_t qfsize = queue_family_props.size();
2032
2033 // Try with avoid preferences first
2034 for (uint32_t i = 0; i < qfsize; i++) {
2035 if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
2036 return i;
2037 }
2038 }
2039
2040 // Fall back to only required
2041 for (size_t i = 0; i < qfsize; i++) {
2042 if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
2043 return i;
2044 }
2045 }
2046
2047 // Fall back to reusing compute queue
2048 for (size_t i = 0; i < qfsize; i++) {
2049 if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
2050 return i;
2051 }
2052 }
2053
    // Fall back to ignoring min_num_queues
2055 for (size_t i = 0; i < qfsize; i++) {
2056 if (queue_family_props[i].queueFlags & required) {
2057 return i;
2058 }
2059 }
2060
2061 // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
2062 // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
2063 if (compute_index >= 0) {
2064 return compute_index;
2065 }
2066
2067 std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
2068
2069 for(auto &q_family : queue_family_props) {
        std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
2071 }
2072 abort();
2073}
2074
2075static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
2076 VK_LOG_DEBUG("ggml_vk_create_queue()");
2077 std::lock_guard<std::recursive_mutex> guard(device->mutex);
2078
2079 q.queue_family_index = queue_family_index;
2080 q.transfer_only = transfer_only;
2081
    q.cmd_pool.init(device, &q);

    q.queue = device->device.getQueue(queue_family_index, queue_index);
2085
2086 q.stage_flags = stage_flags;
2087}
2088
2089static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) {
2090 vk_context result = std::make_shared<vk_context_struct>();
2091 VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
    ctx->gc.contexts.emplace_back(result);
2093 result->p = &p;
2094 return result;
2095}
2096
2097static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) {
2098 vk_context result = std::make_shared<vk_context_struct>();
2099 VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
2100 result->p = &p;
2101 return result;
2102}
2103
2104static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
    VK_LOG_DEBUG("ggml_vk_create_binary_semaphore()");
2106 vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
2107 vk::SemaphoreCreateInfo ci{};
2108 ci.setPNext(&tci);
    vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
    ctx->gc.semaphores.push_back({ semaphore, 0 });
2111 return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
2112}
2113
2114static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
2115 VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
2116 if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
2117 vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
2118 vk::SemaphoreCreateInfo ci{};
2119 ci.setPNext(&tci);
        vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
        ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
2122 }
2123 return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
2124}
2125
2126static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
2127 if (ctx->event_idx >= ctx->gc.events.size()) {
        ctx->gc.events.push_back(ctx->device->device.createEvent({}));
2129 }
2130 return ctx->gc.events[ctx->event_idx++];
2131}
2132
2133static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) {
2134 VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()");
2135
2136 // Requires command buffers to be done
    device->device.resetCommandPool(p.pool);
2138 p.cmd_buffer_idx = 0;
2139}
2140
2141static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
2142 VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()");
2143
2144 // Arbitrary frequency to cleanup/reuse command buffers
2145 static constexpr uint32_t cleanup_frequency = 10;
2146
2147 if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
        ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
2149 }
2150 if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
        ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
2152 }
2153}
2154
2155
2156static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
2157 for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
2158 vk::MemoryType memory_type = mem_props->memoryTypes[i];
2159 if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
2160 (flags & memory_type.propertyFlags) == flags &&
2161 mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
            return i;
2163 }
2164 }
2165 return UINT32_MAX;
2166}
2167
2168static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
2169 VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
2170 if (size > device->max_buffer_size) {
2171 throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
2172 }
2173
2174 vk_buffer buf = std::make_shared<vk_buffer_struct>();
2175
2176 if (size == 0) {
2177 buf->size = 0;
2178 return buf;
2179 }
2180
2181 vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
2182 vk::MemoryAllocateFlags mem_flags {};
2183 if (device->buffer_device_address) {
2184 usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
2185 mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
2186 }
2187
2188 vk::BufferCreateInfo buffer_create_info{
2189 vk::BufferCreateFlags(),
2190 size,
2191 usage_flags,
2192 vk::SharingMode::eExclusive,
2193 0,
2194 nullptr,
2195 };
2196
    buf->buffer = device->device.createBuffer(buffer_create_info);

    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
2200
2201 vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
2202
2203 const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
2204
2205 for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
2206 const auto & req_flags = *it;
2207
        uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
2209
2210 if (memory_type_index == UINT32_MAX) {
2211 continue;
2212 }
2213 buf->memory_property_flags = req_flags;
2214
2215 try {
            buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
2217 break;
2218 } catch (const vk::SystemError& e) {
2219 // loop and retry
2220 // during last attempt throw the exception
2221 if (it + 1 == req_flags_list.end()) {
                device->device.destroyBuffer(buf->buffer);
2223 throw e;
2224 }
2225 }
2226 }
2227
2228 if (!buf->device_memory) {
        device->device.destroyBuffer(buf->buffer);
2230 throw vk::OutOfDeviceMemoryError("No suitable memory type found");
2231 }
2232
2233 buf->ptr = nullptr;
2234
2235 if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
        buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
2237 }
2238
    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
2240
2241 buf->device = device;
2242 buf->size = size;
2243
2244 if (device->buffer_device_address) {
2245 const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
        buf->bda_addr = device->device.getBufferAddress(addressInfo);
2247 }
2248
2249#ifdef GGML_VULKAN_MEMORY_DEBUG
2250 device->memory_logger->log_allocation(buf, size);
2251#endif
2252
2253 return buf;
2254}
2255
2256static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
2257 try {
        return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
2259 } catch (const vk::SystemError& e) {
2260 std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
2261 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
2262 throw e;
2263 }
2264}
2265
2266static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
2267 vk_buffer buf;
2268 try {
2269 if (device->prefer_host_memory) {
            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
                                                       vk::MemoryPropertyFlagBits::eDeviceLocal});
        } else if (device->uma) {
            // Fall back to host memory type
            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
                                                       vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
        } else if (device->disable_host_visible_vidmem) {
            if (device->allow_sysmem_fallback) {
                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
            } else {
                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
            }
        } else {
            // use rebar if available, otherwise fall back to device-only visible memory
            if (device->allow_sysmem_fallback) {
                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
                                                           vk::MemoryPropertyFlagBits::eDeviceLocal,
                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
            } else {
                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
                                                           vk::MemoryPropertyFlagBits::eDeviceLocal});
2292 }
2293 }
2294 } catch (const vk::SystemError& e) {
2295 std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
2296 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
2297 throw e;
2298 }
2299
2300 return buf;
2301}
2302
2303static void ggml_vk_destroy_buffer(vk_buffer& buf) {
2304 if (buf == nullptr) {
2305 return;
2306 }
2307
2308#ifdef GGML_VULKAN_MEMORY_DEBUG
2309 if (buf->device != nullptr) {
2310 buf->device->memory_logger->log_deallocation(buf);
2311 }
2312#endif
2313
2314 buf.reset();
2315}
2316
2317static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {
    return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
2319}
2320
2321static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
2322 VK_LOG_DEBUG("ggml_vk_sync_buffers()");
2323
2324 const bool transfer_queue = subctx->p->q->transfer_only;
2325
2326 if (ctx) {
2327 ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
2328 }
2329
2330 subctx->s->buffer.pipelineBarrier(
        subctx->p->q->stage_flags,
        subctx->p->q->stage_flags,
        {},
        { {
          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
        } },
        {},
        {}
2340 );
2341}
2342
2343static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
2344 VK_LOG_DEBUG("ggml_vk_wait_events()");
2345 if (events.empty()) {
2346 return;
2347 }
2348
2349 ctx->s->buffer.waitEvents(
2350 events,
        ctx->p->q->stage_flags,
        ctx->p->q->stage_flags,
        {},
        {},
        {}
2356 );
2357}
2358
2359// number of rows/cols for flash attention shader
2360static constexpr uint32_t flash_attention_num_small_rows = 32;
2361static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
2362
2363static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
2364 if (hsv >= 192) {
2365 return 2;
2366 } else {
2367 return 8;
2368 }
2369}
2370
2371// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
2372// 128 threads split into four subgroups, each subgroup does 1/4
2373// of the Bc dimension.
2374static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
2375static constexpr uint32_t scalar_flash_attention_Bc = 64;
2376static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
2377
2378static uint32_t get_fa_num_small_rows(FaCodePath path) {
2379 if (path == FA_COOPMAT2) {
2380 return flash_attention_num_small_rows;
2381 } else {
2382 return scalar_flash_attention_num_small_rows;
2383 }
2384}
2385
2386static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
2387 GGML_UNUSED(clamp);
2388 GGML_UNUSED(hsv);
2389
2390 if (path == FA_SCALAR) {
2391 if (small_rows) {
2392 return {scalar_flash_attention_num_small_rows, 64};
2393 } else {
2394 if ((hsv | hsk) & 8) {
2395 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
2396 // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2397 return {get_fa_scalar_num_large_rows(hsv), 64};
2398 } else {
2399 return {get_fa_scalar_num_large_rows(hsv), 32};
2400 }
2401 }
2402 }
2403
2404 if (path == FA_COOPMAT1) {
2405 if (small_rows) {
2406 return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
2407 } else {
2408 return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
2409 }
2410 }
2411
2412 // small rows, large cols
2413 if (small_rows) {
        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
2415 }
2416
2417 // small cols to reduce register count
2418 if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
2419 if (hsk >= 512 || hsv >= 512) {
2420 return {32, 32};
2421 } else {
2422 return {64, 32};
2423 }
2424 }
2425 return {64, 64};
2426}
2427
2428static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
2430}
2431
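// Rough worked example (hypothetical values): on a non-coopmat fp16 device with a
// warptile of { 128, 128, 128, 32, ... }, bank_conflict_offset = 1 and type_size = 2,
// so load_bufs = (128 + 128) * (32 + 1) * 2 = 16896 bytes. Without mul_mat_id or a
// dequant LUT that is the whole estimate, which fits a typical 32 KiB
// maxComputeSharedMemorySize.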
2432static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
2433
2434 uint32_t lut_size = 0;
2435 switch (src0_type) {
2436 case GGML_TYPE_IQ1_S:
2437 case GGML_TYPE_IQ1_M:
2438 lut_size = 2*2048;
2439 break;
2440 case GGML_TYPE_IQ2_XXS:
2441 lut_size = 8*256;
2442 break;
2443 case GGML_TYPE_IQ2_XS:
2444 lut_size = 8*512;
2445 break;
2446 case GGML_TYPE_IQ2_S:
2447 lut_size = 8*1024;
2448 break;
2449 case GGML_TYPE_IQ3_XXS:
2450 lut_size = 4*256;
2451 break;
2452 case GGML_TYPE_IQ3_S:
2453 lut_size = 4*512;
2454 break;
2455 case GGML_TYPE_IQ4_NL:
2456 case GGML_TYPE_IQ4_XS:
2457 case GGML_TYPE_MXFP4:
2458 lut_size = 4*16;
2459 break;
2460 default:
2461 break;
2462 }
2463
2464 // Needs to be kept up to date on shader changes
2465 const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
2466 const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
2467 const uint32_t warps = warptile[0] / warptile[10];
2468
2469 const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
2470 const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
2471 const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
2472 const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
2473
2474 const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
2475 const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
2476
2477 VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
2478 "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
2479
2480 return supported;
2481}
2482
2483struct GpuPipelineConfig {
2484 // GPU architecture identifier.
2485 // Example: vk_device_architecture::AMD_GCN
2486 vk_device_architecture arch;
2487
2488 // Mapping of pipeline names to their specific subgroup sizes.
2489 // Example: {"soft_max_f32", 64}
2490 std::unordered_map<std::string, uint32_t> pipelines;
2491
2492 // Default subgroup size for this GPU.
2493 // Defaults to 0 if not explicitly provided.
2494 uint32_t default_subgroup_size = 0;
2495};
2496
2497// Pipeline configuration for RDNA1 GPUs.
2498static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
2499 {"soft_max", 64}, {"im2col", 64},
2500 {"argmax", 64}, {"mul_mat_vec", 64},
2501 {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
2502};
2503
2504// Pipeline configuration for RDNA2 GPUs.
2505static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
2506 {"soft_max", 64}, {"im2col", 64},
2507};
2508
2509static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
2510
2511// Define configurations for different GPUs.
2512static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
2513 {
        vk_device_architecture::AMD_RDNA1,
        {
            rdna1_pipelines,
        },
        RDNA_DEFAULT_SUBGROUP_SIZE
    },
    {
        vk_device_architecture::AMD_RDNA2,
        {
            rdna2_pipelines,
        },
        RDNA_DEFAULT_SUBGROUP_SIZE
2526 },
2527};
2528
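// Example: on AMD_RDNA1 a (hypothetical) pipeline named "mul_mat_vec_f16_f32_f32_l"
// has no exact entry, but the longest key that is a substring of the name,
// "mul_mat_vec_f16", yields 32; a name with no matching key falls back to the
// architecture's default_subgroup_size.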
2529static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
2530 for (const auto &config : gpu_pipeline_configs) {
2531 if (config.arch == arch) {
            auto pipIt = config.pipelines.find(pipeline_name);
2533 if (pipIt != config.pipelines.end()) {
2534 return pipIt->second;
2535 }
2536 std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
2539 for (const auto &entry : sorted_pipelines) {
                if (pipeline_name.find(entry.first) != std::string::npos) {
2541 return entry.second;
2542 }
2543 }
2544 return config.default_subgroup_size;
2545 }
2546 }
2547 return 0; // If no matching configuration is found
2548}
2549
2550static void ggml_vk_load_shaders(vk_device& device) {
2551 VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
2552
2553 std::lock_guard<std::recursive_mutex> guard(device->mutex);
2554 // some shaders have a minimum subgroup size
    const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
    const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
    const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
2558
2559 const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
    const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
    const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
    const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
2563
2564 const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
2565 (device->subgroup_size_control && device->subgroup_max_size >= 16);
2566
2567 // mulmat
2568 std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
2569 l_warptile_id, m_warptile_id, s_warptile_id,
2570 l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
2571 l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
2572 l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
2573 l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
2574 l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
2575 l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
2576 l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
2577 std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
2578 l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
2579 l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
2580 l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
2581
2582 uint32_t l_align, m_align, s_align;
2583 if (device->coopmat2) {
2584 // spec constants and tile sizes for non-quant matmul/matmul_id
2585 l_warptile = { 256, 128, 256, 64, 1 };
2586 m_warptile = { 256, 128, 128, 64, 0 };
2587 s_warptile = { 128, 64, 64, 64, 0 };
2588 l_wg_denoms = {128, 256, 1 };
2589 m_wg_denoms = {128, 128, 1 };
2590 s_wg_denoms = { 64, 64, 1 };
2591
2592 // spec constants and tile sizes for quant matmul (non-Qi_K)
2593 l_warptile_mmq = { 256, 128, 256, 64, 1 };
2594 m_warptile_mmq = { 256, 128, 128, 64, 1 };
2595 s_warptile_mmq = { 256, 32, 64, 128, 0 };
2596 l_mmq_wg_denoms = { 128, 256, 1 };
2597 m_mmq_wg_denoms = { 128, 128, 1 };
2598 s_mmq_wg_denoms = { 32, 64, 1 };
2599
2600 // spec constants and tile sizes for quant matmul (Qi_K)
2601 l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
2602 m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
2603 s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
2604 l_mmq_wg_denoms_k = { 128, 256, 1 };
2605 m_mmq_wg_denoms_k = { 128, 128, 1 };
2606 s_mmq_wg_denoms_k = { 32, 64, 1 };
2607
2608 // spec constants and tile sizes for quant matmul_id
2609 l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
2610 m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
2611 s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
2612 l_mmqid_wg_denoms = { 128, 128, 1 };
2613 m_mmqid_wg_denoms = { 128, 64, 1 };
2614 s_mmqid_wg_denoms = { 128, 64, 1 };
2615
2616 l_align = 128;
2617 m_align = 64;
2618 s_align = 32;
2619 } else {
2620 // Matrix cores require different warp group sizes
2621 const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
2622 const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
2623 const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
2624 const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
2625 const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
2626 const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
2627 const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
2628 const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
2629 const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
2630
2631 l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2632 m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2633 s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
2634
2635 l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2636 m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2637 s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
2638
2639 // Integer MMQ has a smaller shared memory profile, but heavier register use
2640 l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2641 m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
2642 s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
2643
2644 // K-quants use even more registers, mitigate by setting WMITER to 1
2645 l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
2646 m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
2647 s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
2648
2649 l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
2650 m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
2651 s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
2652
2653 l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
2654 m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
2655 s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
2656
2657 l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
2658 m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
2659 s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
2660
2661 l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
2662 m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
2663 s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
2664
2665 // chip specific tuning
2666 if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
2667 m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
2668 m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
2669 }
2670
2671 l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
2672 m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
2673 s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
2674 l_align = 128;
2675 m_align = 64;
2676 s_align = 32;
2677
2678 for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
2679 ggml_type t = (ggml_type)i;
2680 // Disable medium and large matrix multiplication if not enough shared memory is available
2681 // Check mmq warptiles as the largest configuration
2682 // Throw an error if not enough for any matrix multiplication is available
            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
                std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
                throw std::runtime_error("Shared memory size too small for matrix multiplication.");
            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
                device->mul_mat_m[i] = false;
                device->mul_mat_l[i] = false;
            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
                device->mul_mat_l[i] = false;
            }

            // Disable mul_mat_id if not enough shared memory is available
            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
                device->mul_mat_id_s[i] = false;
                device->mul_mat_id_m[i] = false;
                device->mul_mat_id_l[i] = false;
            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
                device->mul_mat_id_m[i] = false;
                device->mul_mat_id_l[i] = false;
            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
2702 device->mul_mat_id_l[i] = false;
2703 }
2704 }
2705 }
2706
2707 if (!device->pipeline_matmul_f32) {
2708 device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
2709 }
2710 if (!device->pipeline_matmul_f32_f16) {
2711 device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
2712 }
2713 if (!device->pipeline_matmul_id_f32) {
2714 device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
2715 }
2716 if (!device->pipeline_matmul_bf16) {
2717 device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
2718 }
2719 if (!device->pipeline_matmul_id_bf16) {
2720 device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
2721 }
2722
2723 std::vector<std::future<void>> compiles;
2724 auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
2725 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
2726 uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
2727
2728 if (!require_full_subgroups && required_subgroup_size == 0) {
            required_subgroup_size = get_subgroup_size(name, device->architecture);
2730 }
2731
2732 if (!pipeline) {
2733 pipeline = std::make_shared<vk_pipeline_struct>();
2734 }
2735 if (!pipeline->initialized) {
2736 pipeline->name = name;
2737 pipeline->parameter_count = parameter_count;
2738 pipeline->push_constant_size = push_constant_size;
2739 pipeline->wg_denoms = wg_denoms;
2740 pipeline->align = align;
2741 pipeline->initialized = true;
2742 }
2743
2744 if (!pipeline->needed || pipeline->compiled) {
2745 return;
2746 }
2747 // TODO: We're no longer benefitting from the async compiles (shaders are
2748 // compiled individually, as needed) and this complexity can be removed.
2749 {
2750 // wait until fewer than N compiles are in progress
            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
2752 std::unique_lock<std::mutex> guard(compile_count_mutex);
2753 while (compile_count >= N) {
                compile_count_cond.wait(guard);
2755 }
2756 compile_count++;
2757 }
2758
        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
                                      parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
2761 };
2762
2763 auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
2764 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
2765 uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
2766 return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint,
2767 parameter_count, push_constant_size, wg_denoms, specialization_constants,
2768 align, disable_robustness, require_full_subgroups, required_subgroup_size);
2769 };
2770
2771 auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2772 return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
2773 };
2774
2775 auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2776 // For large number of rows, 128 invocations seems to work best.
2777 // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
2778 // can't use 256 for D==80.
2779 // For scalar, use 128 (arbitrary)
2780 // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
2781 const uint32_t D = (hsk|hsv);
2782 uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
2783 ? scalar_flash_attention_workgroup_size
2784 : ((small_rows && (D % 32) == 0) ? 256 : 128);
2785 auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
2786
2787 // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
2788 // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
2789 const uint32_t D_lsb = D ^ (D & (D-1));
        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
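        // Hypothetical examples (assuming subgroup_size >= 8): hsk = hsv = 128 gives
        // D_lsb = 128 and D_split = 8; hsk = hsv = 80 gives D_lsb = 16 and D_split = 4.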
2791
2792 return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
2793 };
2794
2795#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
2796 for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
2797 uint32_t HSK = fa.first.HSK; \
2798 uint32_t HSV = fa.first.HSV; \
2799 bool small_rows = fa.first.small_rows; \
2800 FaCodePath path = fa.first.path; \
2801 bool aligned = fa.first.aligned; \
2802 bool f32acc = fa.first.f32acc; \
2803 if (path == FAPATH) { \
2804 if (aligned) { \
2805 if (f32acc) { \
2806 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2807 } else { \
2808 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2809 } \
2810 } else { \
2811 if (f32acc) { \
2812 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2813 } else { \
2814 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2815 } \
2816 } \
2817 } \
2818 }
2819
2820 CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
2821 CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
2822 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
2823 CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
2824#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
2825 if (device->coopmat1_fa_support) {
2826 CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
2827 CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
2828 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
2829 CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
2830 }
2831#endif
2832#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2833 if (device->coopmat2) {
2834 CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
2835 CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
2836 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
2837 CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
2838 CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
2839 CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
2840 CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
2841 CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
2842 }
2843#endif
2844#undef CREATE_FA
2845
2846#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2847 if (device->coopmat2) {
2848
2849 // Create 6 variants, {s,m,l}x{unaligned,aligned}
2850#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
2851 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
2852 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
2853 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
2854 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
2855 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
2856 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
2857
2858 // Create 2 variants, {f16,f32} accumulator
2859#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
2860 CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
2861 CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
2862
2863 CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
2864#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2865 if (device->coopmat_bf16_support) {
2866 CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
2867 }
2868#endif
2869 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2870 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2871 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2872 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2873 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2874 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
2875 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
2876 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
2877 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
2878 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
2879 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2880 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2881 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2882 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2883 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2884 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2885 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2886 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2887 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2888 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2889
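        // The matmul_id (mul_mat_id, used for MoE expert routing) pipelines below are built from the
        // *_subgroup_* shader variants, which rely on subgroup ballot -- hence the assert.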
2890 GGML_ASSERT(device->subgroup_ballot);
2891
2892 CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
2893#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2894 if (device->coopmat_bf16_support) {
2895 CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
2896 }
2897#endif
2898 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2899 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2900 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2901 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2902 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2903 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2904 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2905 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2906 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2907 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2908 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2909 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2910 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2911 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2912 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2913 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2914 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2915 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2916 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2917 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2918#undef CREATE_MM
2919#undef CREATE_MM2
2920 } else
2921#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2922#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
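    // KHR_cooperative_matrix (coopmat1) path: same pipeline set, but built from the *_cm1 shader
    // blobs and gated per size on the device->mul_mat*_{l,m,s}[TYPE] flags inside CREATE_MM.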
2923 if (device->coopmat_support) {
2924 // Create 6 variants, {s,m,l}x{unaligned,aligned}
2925#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
2926 if (device->mul_mat ## ID ## _l[TYPE]) \
2927 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
2928 if (device->mul_mat ## ID ## _m[TYPE]) \
2929 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
2930 if (device->mul_mat ## ID ## _s[TYPE]) \
2931 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
2932 if (device->mul_mat ## ID ## _l[TYPE]) \
2933 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
2934 if (device->mul_mat ## ID ## _m[TYPE]) \
2935 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
2936 if (device->mul_mat ## ID ## _s[TYPE]) \
2937 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
2938
2939 // Create 2 variants, {f16,f32} accumulator
2940#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
2941 if (device->coopmat_acc_f16_support) { \
2942 CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
2943 } \
2944 if (device->coopmat_acc_f32_support) { \
2945 CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
2946 } \
2947
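        // Unlike the coopmat2 path, CREATE_MM2 here only emits the accumulator variants the device
        // actually advertises (device->coopmat_acc_f16_support / device->coopmat_acc_f32_support).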
2948 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2949 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2950 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2951 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2952#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2953 if (device->coopmat_bf16_support) {
2954 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
2955 }
2956#endif
2957
2958 if (device->coopmat_acc_f16_support) {
2959 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2960 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2961 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2962 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2963 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2964
2965 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2966 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2967 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2968 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2969 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2970 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2971 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2972 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2973 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2974 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2975 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2976 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2977 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2978 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2979 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2980 } else {
2981 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2982 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2983 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2984 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2985 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2986
2987 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2988 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2989 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2990 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2991 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2992 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2993 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2994 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2995 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2996 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2997 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2998 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2999 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3000 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3001 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3002 }
3003
3004 GGML_ASSERT(device->subgroup_ballot);
3005
3006 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3007 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3008 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3009#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3010 if (device->coopmat_bf16_support) {
3011 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3012 }
3013#endif
3014
3015 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3016 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3017 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3018 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3019 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3020 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3021 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3022 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3023 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3024 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3025 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3026 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3027 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3028 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3029 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3030 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3031 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3032 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3033 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3034 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3035#undef CREATE_MM2
3036#undef CREATE_MM
3037 } else
3038#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
3039 if (device->fp16) {
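        // Scalar fp16 path (no cooperative matrix support): the CREATE_MM defined below selects the
        // plain NAMELC##F16ACC##_len/_data shader blobs (no _cm1/_cm2 suffix) and can optionally
        // request a required subgroup size via REQSUBGROUPSIZE.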
3040 // Create 6 variants, {s,m,l}x{unaligned,aligned}
3041#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3042 if (device->mul_mat ## ID ## _l[TYPE]) \
3043 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3044 if (device->mul_mat ## ID ## _m[TYPE]) \
3045 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3046 if (device->mul_mat ## ID ## _s[TYPE]) \
3047 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3048 if (device->mul_mat ## ID ## _l[TYPE]) \
3049 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3050 if (device->mul_mat ## ID ## _m[TYPE]) \
3051 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3052 if (device->mul_mat ## ID ## _s[TYPE]) \
3053 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3054
3055#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3056 if (device->mul_mat ## ID ## _l[TYPE]) { \
3057 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3058 } \
3059 if (device->mul_mat ## ID ## _m[TYPE]) { \
3060 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3061 } \
3062 if (device->mul_mat ## ID ## _s[TYPE]) { \
3063 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3064 } \
3065
3066 // Create 2 variants, {f16,f32} accumulator
3067#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3068 CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3069 CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3070
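        // REQSUBGROUPSIZE == 0 means "no required subgroup size"; a non-zero value is forwarded to
        // ggml_vk_create_pipeline together with a `REQSUBGROUPSIZE > 0` flag (see the trailing
        // arguments of CREATE_MM and CREATE_MMQ above).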
3071 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3072 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3073 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3074 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3075
3076 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3077
3078 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3079 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3080 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3081 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3082 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3083
3084 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3085 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3086 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3087 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3088 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3089 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3090 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3091 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3092 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3093 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3094 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3095 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3096 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3097 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3098 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3099
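        // Integer-dot-product MMQ pipelines: these matmul_*_q8_1 shaders consume activations that are
        // (presumably) requantized to q8_1 on the fly, and CREATE_MMQ only fills the .f32acc slot.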
3100#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3101 if (device->integer_dot_product) {
3102 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3103 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3104 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3105 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3106 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3107
3108 CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3109
3110 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3111 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3112 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3113 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3114 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3115 }
3116#endif
3117
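        // matmul_id pipelines: prefer the *_subgroup_* shader variants when the device offers subgroup
        // ballot with full-subgroup guarantees and a minimum subgroup size of 16; otherwise fall back
        // to the plain matmul_id_* variants below with no subgroup-size requirement (REQSUBGROUPSIZE 0).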
3118 if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
3119 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3120 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3121 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3122 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3123
3124 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3125 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3126 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3127 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3128 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3129 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3130 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3131 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3132 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3133 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3134 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3135 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3136 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3137 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3138 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3139 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3140 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3141 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3142 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3143 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3144
3145#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3146 if (device->integer_dot_product) {
3147 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3148 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3149 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3150 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3151 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3152
3153 CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3154
3155 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3156 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3157 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3158 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3159 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3160 }
3161#endif
3162 } else {
3163 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3164 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3165 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3166 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
3167
3168 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3169 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3170 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3171 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3172 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3173 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3174 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3175 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3176 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3177 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3178 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3179 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3180 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3181 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3182 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3183 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3184 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3185 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3186 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3187 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3188
3189#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3190 if (device->integer_dot_product) {
3191 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3192 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3193 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3194 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3195 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3196
3197 CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3198
3199 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3200 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3201 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3202 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3203 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3204 }
3205#endif
3206 }
3207#undef CREATE_MM2
3208#undef CREATE_MMQ
3209#undef CREATE_MM
3210 } else {
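        // No usable fp16: fall back to the fp32-only shader variants (*_fp32_len/_fp32_data), which
        // provide only an f32 accumulator.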
3211 // Create 6 variants, {s,m,l}x{unaligned,aligned}
3212#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3213 if (device->mul_mat ## ID ## _l[TYPE]) \
3214 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3215 if (device->mul_mat ## ID ## _m[TYPE]) \
3216 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3217 if (device->mul_mat ## ID ## _s[TYPE]) \
3218 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3219 if (device->mul_mat ## ID ## _l[TYPE]) \
3220 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3221 if (device->mul_mat ## ID ## _m[TYPE]) \
3222 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3223 if (device->mul_mat ## ID ## _s[TYPE]) \
3224 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3225
3226#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
3227 if (device->mul_mat ## ID ## _l[TYPE]) \
3228 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
3229 if (device->mul_mat ## ID ## _m[TYPE]) \
3230 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
3231 if (device->mul_mat ## ID ## _s[TYPE]) \
3232 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
3233
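        // Note: in this fp32 fallback, CREATE_MMQ builds only the three unaligned l/m/s variants
        // (align = 1); no _aligned q8_1 pipelines are created here.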
3234 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3235 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3236 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3237 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3238
3239 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3240
3241 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3242 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3243 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3244 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3245 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3246
3247 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3248 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3249 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3250 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3251 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3252 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3253 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3254 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3255 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3256 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3257 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3258 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3259 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3260 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3261 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3262
3263#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3264 if (device->integer_dot_product) {
3265 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3266 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3267 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3268 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3269 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3270
3271 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3272 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3273 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3274 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3275 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3276 }
3277#endif
3278
3279 if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
3280 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3281 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3282 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3283 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3284
3285 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3286 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3287 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3288 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3289 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3290 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3291 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3292 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3293 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3294 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3295 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3296 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3297 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3298 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3299 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3300 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3301 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3302 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3303 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3304 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3305 } else {
            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
3309 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
3310
3311 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3312 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3313 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3314 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3315 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3316 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3317 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3318 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3319 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3320 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3321 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3322 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3323 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3324 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3325 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3326 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3327 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3328 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3329 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3330 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3331 }
3332 }
3333 // reusing CREATE_MM from the fp32 path
3334 if ((device->coopmat2 || device->coopmat_support)
3335#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3336 && !device->coopmat_bf16_support
3337#endif
3338 ) {
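        // The coopmat path above was taken but provides no bf16 coopmat shaders here,
        // so scalar bf16 matmul pipelines are registered as a fallback.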
3339 // use scalar tile sizes
3340 l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
3341 m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
3342 s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
3343
3344 l_wg_denoms = {128, 128, 1 };
3345 m_wg_denoms = { 64, 64, 1 };
3346 s_wg_denoms = { 32, 32, 1 };
3347
3348 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3349 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
3350 }
3351#undef CREATE_MM
3352
3353 // mul mat vec
3354
3355 // the number of rows computed per shader depends on GPU model and quant
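    // rm_stdq applies to the "standard" quants (Q4_0..Q8_0), rm_kq to the k-quants and
    // rm_iq to the iq/mxfp4 types; AMD GCN bumps rm_stdq and rm_kq, Intel bumps rm_stdq.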
3356 uint32_t rm_stdq = 1;
3357 uint32_t rm_kq = 2;
3358 if (device->vendor_id == VK_VENDOR_ID_AMD) {
3359 if (device->architecture == AMD_GCN) {
3360 rm_stdq = 2;
3361 rm_kq = 4;
3362 }
3363 } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
3364 rm_stdq = 2;
3365 uint32_t rm_iq = 2 * rm_kq;
3366
3367 const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
3368 // Ensure a subgroup size >= 16 is available
3369 const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
3370
3371 const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;
    const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);
3373
3374 const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
3375 const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
3376
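    // Two workgroup-size variants are created per pipeline: one where the workgroup is a
    // single subgroup and one that is four subgroups wide; the reduction mode is chosen to
    // match (subgroup-only, hybrid, or shared-memory-only).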
3377 for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
3378 const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
3379 const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);
3380
3381 const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
3382 (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
3383 SHADER_REDUCTION_MODE_SHMEM;
3384
3385 const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
3386 (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
3387 SHADER_REDUCTION_MODE_SHMEM;
3388
3389 for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
3390 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3391 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3392 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3393 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3394 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3395 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3396 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3397 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3398 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3399 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3400 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3401 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3402 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3403 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3404 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3405 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3406 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3407 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3408 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3409 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3410 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3411 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3412 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3413
3414 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3415 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3416 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3417 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3418 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3419 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3420 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3421 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3422 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3423 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3424 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3425 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3426 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3427 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3428 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3429 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3430 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3431 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3432 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3433 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3434 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3435 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3436 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", 4, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3437
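            // Variants that consume src1 pre-quantized to q8_1 (see the quantize_q8_1
            // pipelines below), built only when the device supports integer dot product.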
3438#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3439 if (device->integer_dot_product) {
3440 const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
3441 const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
3442
3443 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3444 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3445 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3446 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3447 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3448 }
3449#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
3450 }
3451 }
3452
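    // mul_mat_vec for MUL_MAT_ID: one pipeline per source type, with a fifth descriptor
    // binding the ids tensor that selects the rows/experts.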
3453 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
3454 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
3455 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
3456 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
3457 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
3458 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
3459 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
3460 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
3461 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
3462 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
3463 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
3464 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
3465 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
3466 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3467 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3468 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3469 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3470 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3471 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3472 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3473 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3474 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3475 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
3476
3477 // dequant shaders
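    // These convert quantized blocks (and f32) to f16, presumably for the path that
    // dequantizes the whole matrix before running a regular f16 matmul.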
3478 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3479 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3480 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3481 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3482 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3483 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3484 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3485 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3486 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3487 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3488 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3489 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3490 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3491 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3492 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3493 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3494 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3495 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3496 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3497 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3498 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3499
3500 // get_rows
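    // Two tables follow: pipeline_get_rows here and pipeline_get_rows_f32 below; judging by
    // the shader names they differ only in the destination type (f16 vs f32).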
3501 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3502 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3503 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3504 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3505 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3506 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3507 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3508 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3509 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3510 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3511 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3512 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3513 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3514 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3515 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3516 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3517 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3518 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3519 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3520 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3521 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3522 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3523 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3524
3525 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3526 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3527 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3528 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3529 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3530 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3531 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3532 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3533 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3534 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3535 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3536 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3537 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3538 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3539 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3540 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3541 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3542 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3543 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3544 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3545 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3546 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3547 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3548
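    // Reduction shaders that combine the partial results of split-k matmul and split-k
    // flash attention dispatches.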
3549 ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
3550 ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
3551
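    // Quantize src1 to q8_1 for the integer-dot mul_mat paths; the subgroup-clustered
    // variant is used when clustered subgroup ops with full subgroups are available.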
3552 if (device->subgroup_clustered && device->subgroup_require_full_support) {
3553 ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
3554 ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
3555 } else {
3556 ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
3557 ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
3558 }
3559
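    // mul_mat_vec for the permuted p021 f16 case, specialized per GQA ratio (i + 1), with a
    // subgroup-add variant when full subgroup arithmetic support is available.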
3560 for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
3561 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 4, 7 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
3563 } else {
            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 4, 7 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
3565 }
3566 }
3567 ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 4, 13 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
3568
3569 ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
3570 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
3571
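    // Note: the *_mul_* variants reuse the same SPIR-V as the plain rms_norm shaders; only
    // the second specialization constant ({0, 0} vs {0, 1}) differs.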
3572 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
3573 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
3574 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
3575 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
3576
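    // Fused rms_norm + mul + rope is only registered when fp16 round-to-nearest-even float
    // controls are available and the push constants fit within the device limit.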
3577 if (device->float_controls_rte_fp16 &&
3578 sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
3579 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
3580 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
3581 }
3582
3583 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
3584 ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
3585
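    // Copy/cast pipelines between f32/f16/bf16/i32; the contig_cpy_* variants below handle
    // the fully contiguous case.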
3586 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3587 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3588 ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3589 ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3590 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3591 ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3592 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3593
3594 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3595 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3596 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3597 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3598 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3599 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3600 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3601
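    // f32 -> quantized copies; the *_rte shader builds are used when the device guarantees
    // round-to-nearest-even fp16 conversions.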
3602 if (device->float_controls_rte_fp16) {
3603 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3604 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3605 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3606 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3607 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3608 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3609 } else {
3610 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3611 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3612 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3613 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3614 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3615 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
3616 }
3617
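// Note: SET_ROWS(itype, rte) relies on token pasting to pick the generated shader arrays,
// e.g. SET_ROWS(_i32, _rte) turns set_rows_f32 ## itype ## rte ## _len into
// set_rows_f32_i32_rte_len. One pipeline is created per destination type for the given
// index type (i32/i64) and rounding-mode variant.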
3618#define SET_ROWS(itype, rte) \
3619 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3620 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3621 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3622 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3623 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3624 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3625 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3626 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
3627 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
3628
3629 if (device->float_controls_rte_fp16) {
3630 SET_ROWS(_i32, _rte)
3631 SET_ROWS(_i64, _rte)
3632 } else {
3633 SET_ROWS(_i32, )
3634 SET_ROWS(_i64, )
3635 }
3636#undef SET_ROWS
3637
3638
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
3645
3646 auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
3647 std::string s;
3648 s += std::string(src0_f16 ? "_f16" : "_f32");
3649 s += std::string(src1_f16 ? "_f16" : "_f32");
3650 s += std::string(dst_f16 ? "_f16" : "_f32");
3651 return s;
3652 };
3653
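    // The CREATE_BINARY macro below instantiates one pipeline per src0/src1/dst f32-or-f16
    // combination; get_suffix builds the matching "_fXX_fXX_fXX" debug name, and `rte`
    // indexes the round-to-nearest-even shader variant when FP16 RTE float controls are
    // available.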
3654 bool rte = device->float_controls_rte_fp16;
3655#define CREATE_BINARY(name, namemod, spec, bindings) \
3656 for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
3657 ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
3658 #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
3659 "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
3660
3661 CREATE_BINARY(add, , {0}, 4)
3662 CREATE_BINARY(add, _norepeat, {1}, 4)
3663 CREATE_BINARY(sub, , {0}, 3)
3664 CREATE_BINARY(sub, _norepeat, {1}, 3)
3665 CREATE_BINARY(mul, , {0}, 3)
3666 CREATE_BINARY(mul, _norepeat, {1}, 3)
3667 CREATE_BINARY(div, , {0}, 3)
3668 CREATE_BINARY(div, _norepeat, {1}, 3)
3669 CREATE_BINARY(add_rms, , {0}, 4)
3670 CREATE_BINARY(add_rms, _norepeat, {1}, 4)
3671#undef CREATE_BINARY
3672
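    // One multi_add variant per fused-add count; the specialization constant (i+2) appears
    // to tell the shader how many input operands variant i consumes.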
3673 if (device->multi_add) {
3674 for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
            ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
            ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
3677 }
3678 }
3679
3680 ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
3681
3682 ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
3683
3684 ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
3685 ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
3686 ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
3687
3688 ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
3689 ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
3690
3691 ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3692
3693 ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3694 ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3695 ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3696 ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3697
3698 ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3699
3700 ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
3701
3702 ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3703
3704 ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3705 ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3706
3707#define CREATE_UNARY(name) \
3708 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3709 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3710
3711 CREATE_UNARY(gelu)
3712 CREATE_UNARY(gelu_erf)
3713 CREATE_UNARY(gelu_quick)
3714 CREATE_UNARY(silu)
3715 CREATE_UNARY(relu)
3716 CREATE_UNARY(tanh)
3717 CREATE_UNARY(sigmoid)
3718 CREATE_UNARY(hardsigmoid)
3719 CREATE_UNARY(hardswish)
3720#undef CREATE_UNARY
3721
3722#define CREATE_UNARY_RTE(name) \
3723 if (device->float_controls_rte_fp16) { \
3724 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3725 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3726 } else { \
3727 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3728 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3729 }
3730 CREATE_UNARY_RTE(exp)
3731#undef CREATE_UNARY_RTE
3732
3733#define CREATE_GLU(name) \
3734 if (device->float_controls_rte_fp16) { \
3735 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
3736 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
3737 } else { \
3738 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
3739 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
3740 }
3741
3742 CREATE_GLU(geglu)
3743 CREATE_GLU(reglu)
3744 CREATE_GLU(swiglu)
3745 CREATE_GLU(swiglu_oai)
3746 CREATE_GLU(geglu_erf)
3747 CREATE_GLU(geglu_quick)
3748#undef CREATE_GLU
3749
3750 ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3751 ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3752
3753 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
3754
3755 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3756 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3757 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3758 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3759 ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
3760
3761 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3762 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3763 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3764 ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3765
3766 if (device->float_controls_rte_fp16) {
3767 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3768 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3769 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3770 ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3771
3772 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3773 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3774 } else {
3775 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3776 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3777 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3778 ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3779
3780 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3781 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3782 }
3783
3784 for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
        ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
3786 }
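    // Each argsort variant appears to use a workgroup of 2^i invocations; the specialization
    // constants pass the block size (1u<<i) and its log2 (i) to the shader, presumably so
    // variant i can handle rows of up to 2^i elements in shared memory.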
3787
3788 ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3789
3790 ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3791
3792 ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
3793
3794#define IM2COL(bda) \
3795 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3796 ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3797 if (device->float_controls_rte_fp16) { \
3798 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3799 ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3800 } else { \
3801 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3802 ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3803 }
3804 if (device->shader_int64 && device->buffer_device_address) {
3805 IM2COL(_bda)
3806 } else {
3807 IM2COL()
3808 }
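    // The _bda im2col variants are only compiled in when the device supports shaderInt64 and
    // bufferDeviceAddress; they are assumed to address the source data through 64-bit buffer
    // device addresses instead of plain descriptor offsets.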
3809
3810 ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
3811
3812 ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
3813
3814 ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
3815
3816 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
3817
3818 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
3819
3820 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
3821 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3822 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3823 } else {
3824 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3825 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3826 }
3827
3828 ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
3829
3830 ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3831
3832 ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3833
3834 // conv2d, conv_transpose_2d
3835 for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
3836 uint32_t conv2d_WG_SIZE = 256;
3837 uint32_t conv2d_BS_K = 128;
3838 uint32_t conv2d_BS_CRS = 16;
        uint32_t use_collectives = 0; // Enable subgroup ops to avoid recomputing indices.
3840 uint32_t conv2d_BS_NPQ = 128;
3841 uint32_t conv2d_TS_K = 8;
3842 uint32_t conv2d_SHMEM_PAD = 4;
3843 bool conv2d_UNROLL = true;
3844
3845#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
3846 if (device->coopmat2) {
3847 conv2d_SHMEM_PAD = 8; // 8 float16_t
3848 }
3849#endif
3850
3851 if (device->vendor_id == VK_VENDOR_ID_INTEL) {
3852 conv2d_SHMEM_PAD = 0;
3853 conv2d_UNROLL = false;
3854 } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
3855 conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
3856 }
3857
3858 switch (s) {
3859 default:
3860 case CONV_SHAPE_128x128:
3861 conv2d_BS_K = 128;
3862 conv2d_BS_NPQ = 128;
3863 conv2d_BS_CRS = 16;
3864 if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
3865 conv2d_UNROLL = false;
3866 }
3867 break;
3868 case CONV_SHAPE_64x32:
3869 conv2d_BS_K = 64;
3870 conv2d_BS_NPQ = 32;
3871 conv2d_BS_CRS = 32;
3872 conv2d_TS_K = 4;
3873 break;
3874 case CONV_SHAPE_32x256:
3875 conv2d_BS_K = 32;
3876 conv2d_BS_NPQ = 256;
3877 conv2d_BS_CRS = 16;
3878 break;
3879 }
3880
        // On NVIDIA and AMD, only enable collectives on pre-Turing and GCN architectures
        // respectively; their slower integer math benefits from sharing the index calculation.
        bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
                                    device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
        bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
                                     device->architecture == vk_device_architecture::AMD_GCN;

        if (device->subgroup_shuffle &&
            device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
            allow_collectives_nv &&
            allow_collectives_amd) {
            use_collectives = 1;
            conv2d_BS_CRS = std::min(
                device->subgroup_size,
                conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
3895 }
3896
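        // Rough shared-memory estimate for the two GEMM-style tiles staged per iteration:
        // a BS_K x BS_CRS tile and a BS_CRS x BS_NPQ tile, each padded by SHMEM_PAD columns
        // (presumably to avoid bank conflicts), counted as floats. With the default 128x128
        // parameters above (BS_CRS = 16, SHMEM_PAD = 4) this is (128*20 + 16*132) * 4 B,
        // roughly 18 KiB. If it exceeds the device limit, the CRS block is shrunk below.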
3897 uint32_t conv2d_shmem_req =
3898 (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
3899 if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
3900 conv2d_BS_CRS = 8;
3901 if (use_collectives) {
                conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
3903 }
3904 }
3905
3906 std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
3907 std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
3908
3909#define CREATE_CONV(name, type_suffix, spv_suffix) \
3910 ggml_vk_create_pipeline( \
3911 device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
3912 name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
3913 sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3914#define CREATE_CONVS(spv_suffix) \
3915 CREATE_CONV(conv2d, _f32, spv_suffix) \
3916 CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
3917 if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
3918 CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
3919 CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
3920 }
3921#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
3922 if (device->coopmat2) {
3923 CREATE_CONVS(_cm2)
3924 } else
3925#endif
3926 if (conv2d_UNROLL) {
3927 CREATE_CONVS(_unroll)
3928 } else {
3929 CREATE_CONVS( )
3930 }
3931#undef CREATE_CONV
3932#undef CREATE_CONVS
3933 }
3934
3935 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3936 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3937 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3938 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3939
3940 for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
3944 }
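    // The third and fourth specialization constants select the softmax handling, matching the
    // debug names above (early softmax, early softmax + normalization, late softmax); 1u<<i
    // looks like the per-variant limit on the number of experts.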
3945
3946 for (auto &c : compiles) {
3947 c.wait();
3948 }
3949}
3950
3951static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
3952
3953static vk_device ggml_vk_get_device(size_t idx) {
3954 VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
3955
3956 if (vk_instance.devices[idx] == nullptr) {
3957 VK_LOG_DEBUG("Initializing new vk_device");
3958 vk_device device = std::make_shared<vk_device_struct>();
3959 vk_instance.devices[idx] = device;
3960
3961#ifdef GGML_VULKAN_MEMORY_DEBUG
3962 device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
3963#endif
3964 if (vk_perf_logger_enabled) {
3965 device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
3966 }
3967
3968 size_t dev_num = vk_instance.device_indices[idx];
3969
3970 std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();
3971
3972 if (dev_num >= physical_devices.size()) {
3973 std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
3974 throw std::runtime_error("Device not found");
3975 }
3976
3977 device->physical_device = physical_devices[dev_num];
3978 const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
3979
        device->architecture = get_device_architecture(device->physical_device);

        const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
        device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;

        const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
        device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;

        const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
        device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;

        const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE");
        device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr;
3993
3994 bool fp16_storage = false;
3995 bool fp16_compute = false;
3996 bool maintenance4_support = false;
3997 bool sm_builtins = false;
3998 bool amd_shader_core_properties2 = false;
3999 bool pipeline_robustness = false;
4000 bool coopmat2_support = false;
4001 bool pipeline_executable_properties_support = false;
4002 device->coopmat_support = false;
4003 device->integer_dot_product = false;
4004 bool bfloat16_support = false;
4005
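        // This pass only records which optional extensions the physical device advertises;
        // the corresponding feature/property structs are chained into the queries below
        // before anything is enabled on the logical device.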
4006 for (const auto& properties : ext_props) {
            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                maintenance4_support = true;
            } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
                fp16_storage = true;
            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
                fp16_compute = true;
            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
                sm_builtins = true;
            } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
                amd_shader_core_properties2 = true;
            } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
                pipeline_robustness = true;
            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
                device->subgroup_size_control = true;
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
            } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_COOPMAT")) {
                device->coopmat_support = true;
                device->coopmat_m = 0;
                device->coopmat_n = 0;
                device->coopmat_k = 0;
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
            } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_COOPMAT2")) {
                coopmat2_support = true;
#endif
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
                device->integer_dot_product = true;
#endif
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
            } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                bfloat16_support = true;
#endif
            } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
                pipeline_executable_properties_support = true;
            }
        }
4048
4049 vk::PhysicalDeviceProperties2 props2;
4050 vk::PhysicalDeviceMaintenance3Properties props3;
4051 vk::PhysicalDeviceMaintenance4Properties props4;
4052 vk::PhysicalDeviceSubgroupProperties subgroup_props;
4053 vk::PhysicalDeviceDriverProperties driver_props;
4054 vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
4055 vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
4056 vk::PhysicalDeviceVulkan11Properties vk11_props;
4057 vk::PhysicalDeviceVulkan12Properties vk12_props;
4058 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
4059 vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
4060
4061 props2.pNext = &props3;
4062 props3.pNext = &subgroup_props;
4063 subgroup_props.pNext = &driver_props;
4064 driver_props.pNext = &vk11_props;
4065 vk11_props.pNext = &vk12_props;
4066
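        // last_struct tracks the tail of the pNext chain so optional property structs are
        // appended only when the matching extension was found above.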
4067 VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
4068
4069 if (maintenance4_support) {
4070 last_struct->pNext = (VkBaseOutStructure *)&props4;
4071 last_struct = (VkBaseOutStructure *)&props4;
4072 }
4073 if (sm_builtins) {
4074 last_struct->pNext = (VkBaseOutStructure *)&sm_props;
4075 last_struct = (VkBaseOutStructure *)&sm_props;
4076 }
4077 if (amd_shader_core_properties2) {
4078 last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
4079 last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
4080 }
4081 if (device->subgroup_size_control) {
4082 last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
4083 last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
4084 }
4085
4086#if defined(VK_NV_cooperative_matrix2)
4087 vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
4088 if (coopmat2_support) {
4089 last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
4090 last_struct = (VkBaseOutStructure *)&coopmat2_props;
4091 }
4092#endif
4093
4094 if (device->integer_dot_product) {
4095 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
4096 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
4097 }
4098
        device->physical_device.getProperties2(&props2);
        device->properties = props2.properties;
        device->vendor_id = device->properties.vendorID;
        device->driver_id = driver_props.driverID;

        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");

        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
            device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
        } else if (maintenance4_support) {
            device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
        } else {
            device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
        }

        const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE");

        if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {
            device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);
        } else if (maintenance4_support) {
            device->max_buffer_size = props4.maxBufferSize;
        } else {
            device->max_buffer_size = device->max_memory_allocation_size;
        }

        const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");

        if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
            device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
        } else {
            // Limit batching of allocations to 1GB by default to avoid fragmentation issues
            device->suballocation_block_size = 1024*1024*1024;
        }
        device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
4133
4134 device->subgroup_size = subgroup_props.subgroupSize;
4135 device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
4136 if (sm_builtins) {
4137 device->shader_core_count = sm_props.shaderSMCount;
4138 } else if (amd_shader_core_properties2) {
4139 device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
4140 } else {
4141 device->shader_core_count = 0;
4142 }
4143 device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
4144
4145 device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4146 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
4147#ifdef __APPLE__
4148 // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846)
4149 if (device->vendor_id == VK_VENDOR_ID_AMD) {
4150 device->subgroup_arithmetic = false;
4151 }
4152#endif
4153 device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4154 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
4155 device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4156 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
4157
4158 device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4159 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
4160
        const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;

        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
4166 device->coopmat_support = false;
4167 }
4168
4169 device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
4170
4171 std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
4172
4173 // Try to find a non-graphics compute queue and transfer-focused queues
        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);

        const float priorities[] = { 1.0f, 1.0f };
        device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;

        std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
        if (compute_queue_family_index != transfer_queue_family_index) {
            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
        } else if (!device->single_queue) {
            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
        } else {
            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
        }
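        // Queue selection: prefer a dedicated transfer queue family; otherwise request a
        // second queue from the compute family if one exists, and fall back to a single
        // shared queue when neither is available (the transfer queue then aliases the
        // compute queue later in this function).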
4189 vk::DeviceCreateInfo device_create_info;
4190 std::vector<const char *> device_extensions;
4191 vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
4192
4193 VkPhysicalDeviceFeatures2 device_features2;
4194 device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
4195 device_features2.pNext = nullptr;
4196 device_features2.features = (VkPhysicalDeviceFeatures)device_features;
4197
4198 VkPhysicalDeviceVulkan11Features vk11_features;
4199 vk11_features.pNext = nullptr;
4200 vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
4201 device_features2.pNext = &vk11_features;
4202
4203 VkPhysicalDeviceVulkan12Features vk12_features;
4204 vk12_features.pNext = nullptr;
4205 vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
4206 vk11_features.pNext = &vk12_features;
4207
4208 last_struct = (VkBaseOutStructure *)&vk12_features;
4209
4210 VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
4211 pl_robustness_features.pNext = nullptr;
4212 pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
4213 pl_robustness_features.pipelineRobustness = VK_FALSE;
4214
4215 if (pipeline_robustness) {
4216 last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
4217 last_struct = (VkBaseOutStructure *)&pl_robustness_features;
            device_extensions.push_back("VK_EXT_pipeline_robustness");
4219 }
4220
4221 VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
4222 subgroup_size_control_features.pNext = nullptr;
4223 subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
4224 subgroup_size_control_features.computeFullSubgroups = false;
4225 subgroup_size_control_features.subgroupSizeControl = false;
4226
4227 if (device->subgroup_size_control) {
4228 last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
4229 last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
4230 }
4231
4232#if defined(VK_KHR_cooperative_matrix)
4233 VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
4234 coopmat_features.pNext = nullptr;
4235 coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
4236 coopmat_features.cooperativeMatrix = VK_FALSE;
4237
4238 if (device->coopmat_support) {
4239 last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
4240 last_struct = (VkBaseOutStructure *)&coopmat_features;
4241 }
4242#endif
4243
4244#if defined(VK_NV_cooperative_matrix2)
4245 VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
4246 coopmat2_features.pNext = nullptr;
4247 coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
4248 if (coopmat2_support) {
4249 last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
4250 last_struct = (VkBaseOutStructure *)&coopmat2_features;
            device_extensions.push_back("VK_NV_cooperative_matrix2");
4252 }
4253#endif
4254
4255#if defined(VK_KHR_shader_bfloat16)
4256 VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
4257 bfloat16_features.pNext = nullptr;
4258 bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
4259 if (bfloat16_support) {
4260 last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
4261 last_struct = (VkBaseOutStructure *)&bfloat16_features;
            device_extensions.push_back("VK_KHR_shader_bfloat16");
4263 }
4264#endif
4265
4266 VkPhysicalDeviceMaintenance4Features maint4_features {};
4267 maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
4268 if (maintenance4_support) {
4269 last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
4270 last_struct = (VkBaseOutStructure *)&maint4_features;
            device_extensions.push_back("VK_KHR_maintenance4");
4272 }
4273
4274 VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
4275 shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
4276 if (device->integer_dot_product) {
4277 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
4278 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
            device_extensions.push_back("VK_KHR_shader_integer_dot_product");
4280 }
4281
4282 VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
4283 pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
4284 if (pipeline_executable_properties_support) {
4285 last_struct->pNext = (VkBaseOutStructure *)&pep_features;
4286 last_struct = (VkBaseOutStructure *)&pep_features;
            device_extensions.push_back("VK_KHR_pipeline_executable_properties");
4288 }
4289
        vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
4291
4292 device->pipeline_executable_properties_support = pipeline_executable_properties_support;
4293
4294 device->fp16 = device->fp16 && vk12_features.shaderFloat16;
4295
4296#if defined(VK_KHR_shader_bfloat16)
4297 device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
4298#else
4299 device->bf16 = false;
4300#endif
4301
4302 device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
4303
        device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
                            device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
                            getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
4307
4308 device->shader_int64 = device_features2.features.shaderInt64;
4309 device->buffer_device_address = vk12_features.bufferDeviceAddress;
4310
4311 if (device->subgroup_size_control) {
4312 device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
4313 device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
            device_extensions.push_back("VK_EXT_subgroup_size_control");
4315 }
4316
4317 device->subgroup_size_control = device->subgroup_size_control &&
4318 (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
4319 subgroup_size_control_features.subgroupSizeControl;
4320
4321 device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
4322
4323#if defined(VK_KHR_cooperative_matrix)
4324 device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
4325
4326 // coopmat1 fa shader currently assumes 32 invocations per subgroup
4327 device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
4328 device->subgroup_size_control && device->subgroup_min_size <= 32 &&
4329 device->subgroup_max_size >= 32;
4330#endif
4331
4332 if (coopmat2_support) {
4333#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
4334 if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
4335 coopmat2_features.cooperativeMatrixFlexibleDimensions &&
4336 coopmat2_features.cooperativeMatrixReductions &&
4337 coopmat2_features.cooperativeMatrixConversions &&
4338 coopmat2_features.cooperativeMatrixPerElementOperations &&
4339 coopmat2_features.cooperativeMatrixTensorAddressing &&
4340 coopmat2_features.cooperativeMatrixBlockLoads &&
4341 vk12_features.bufferDeviceAddress) {
4342
4343 std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
4344 uint32_t count = 0;
4345
4346 PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
4347 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
4348 (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
                    vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
4350
4351 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
4352
4353 VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
4354 empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
                flexible_dimensions.resize(count, empty_prop);
4356
4357 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
4358
4359 bool found_fp16_128 = false,
4360 found_fp16_256 = false,
4361 found_fp32_128 = false,
4362 found_fp32_256 = false;
                // Require fp16*fp16 with both fp16 and fp32 accumulators: 32x16x16 granularity
                // at workgroup size 128 and 32x32x16 granularity at workgroup size 256.
4365 for (auto &prop : flexible_dimensions) {
4366 if (prop.saturatingAccumulation == VK_FALSE &&
4367 prop.scope == VK_SCOPE_WORKGROUP_KHR &&
4368 prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
4369 prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
4370
4371 if (prop.workgroupInvocations == 128 &&
4372 prop.MGranularity <= 32 &&
4373 prop.NGranularity <= 16 &&
4374 prop.KGranularity <= 16) {
4375 if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
4376 prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
4377 found_fp16_128 = true;
4378 }
4379 if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
4380 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
4381 found_fp32_128 = true;
4382 }
4383 }
4384 if (prop.workgroupInvocations == 256 &&
4385 prop.MGranularity <= 32 &&
4386 prop.NGranularity <= 32 &&
4387 prop.KGranularity <= 16) {
4388 if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
4389 prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
4390 found_fp16_256 = true;
4391 }
4392 if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
4393 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
4394 found_fp32_256 = true;
4395 }
4396 }
4397 }
4398 }
4399 if (found_fp16_128 && found_fp16_256 &&
4400 found_fp32_128 && found_fp32_256 &&
4401 coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
4402 device->coopmat2 = true;
4403 }
4404 }
4405#endif
4406 }
4407
4408 if (!vk11_features.storageBuffer16BitAccess) {
4409 std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
4410 throw std::runtime_error("Unsupported device");
4411 }
4412
        device_extensions.push_back("VK_KHR_16bit_storage");
4414
4415#ifdef GGML_VULKAN_VALIDATE
4416 device_extensions.push_back("VK_KHR_shader_non_semantic_info");
4417#endif
4418
4419 if (device->fp16) {
            device_extensions.push_back("VK_KHR_shader_float16_int8");
4421 }
4422
4423#if defined(VK_KHR_cooperative_matrix)
4424 if (device->coopmat_support) {
4425 // Query supported shapes
4426 std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
4427
4428 PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
                (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
4430
4431 uint32_t cm_props_num;
4432
4433 pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
4434
            cm_props.resize(cm_props_num);
4436
4437 for (auto& prop : cm_props) {
4438 prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
4439 }
4440
4441 pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
4442
4443 VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
4444
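            // The first suitable fp16 subgroup-scope shape fixes coopmat_m/n/k; additional
            // accumulator types and bf16 are only enabled afterwards if they report that exact
            // MxNxK shape, while int8 shapes are tracked separately and 16x16x16 support gets
            // its own flags.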
4445 for (auto& prop : cm_props) {
4446 VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
4447
4448 if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
4449 (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
4450 (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
4451 ) {
4452 if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
4453 (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
4454 // coopmat sizes not set yet
4455 if (device->coopmat_m == 0) {
4456 device->coopmat_acc_f32_support = true;
4457 device->coopmat_m = prop.MSize;
4458 device->coopmat_n = prop.NSize;
4459 device->coopmat_k = prop.KSize;
4460 } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
4461 // Only enable if shape is identical
4462 device->coopmat_acc_f32_support = true;
4463 }
4464 if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
4465 device->coopmat_support_16x16x16_f32acc = true;
4466 }
4467 } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
4468 (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
4469 // coopmat sizes not set yet
4470 if (device->coopmat_m == 0) {
4471 device->coopmat_acc_f16_support = true;
4472 device->coopmat_m = prop.MSize;
4473 device->coopmat_n = prop.NSize;
4474 device->coopmat_k = prop.KSize;
4475 } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
4476 // Only enable if shape is identical
4477 device->coopmat_acc_f16_support = true;
4478 }
4479 if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
4480 device->coopmat_support_16x16x16_f16acc = true;
4481 }
4482 }
4483 } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
4484 (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
4485 (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
4486 (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
4487 (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
4488 device->coopmat_int_m == 0
4489 ) {
4490 device->coopmat_int_support = true;
4491 device->coopmat_int_m = prop.MSize;
4492 device->coopmat_int_n = prop.NSize;
4493 device->coopmat_int_k = prop.KSize;
4494 }
4495#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
4496 if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
4497 prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
4498 prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
4499 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
4500 (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
4501 ) {
4502 // coopmat sizes not set yet
4503 if (device->coopmat_m == 0) {
4504 device->coopmat_bf16_support = true;
4505 device->coopmat_m = prop.MSize;
4506 device->coopmat_n = prop.NSize;
4507 device->coopmat_k = prop.KSize;
4508 } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
4509 // Only enable if shape is identical
4510 device->coopmat_bf16_support = true;
4511 }
4512 }
4513#endif
4514 }
4515
4516 if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
4517 // No suitable matmul mode found
4518 GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
4519 device->coopmat_support = false;
4520 }
            if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
4522 device->coopmat_bf16_support = false;
4523 }
4524 }
4525
4526 if (device->coopmat_support) {
            device_extensions.push_back("VK_KHR_cooperative_matrix");
4528 }
4529#if defined(VK_KHR_shader_bfloat16)
4530 if (device->coopmat_bf16_support) {
            device_extensions.push_back("VK_KHR_shader_bfloat16");
4532 }
4533#endif
4534#endif
        device->name = GGML_VK_NAME + std::to_string(idx);
4536
4537 device_create_info = {
4538 vk::DeviceCreateFlags(),
4539 device_queue_create_infos,
4540 {},
4541 device_extensions
4542 };
4543 device_create_info.setPNext(&device_features2);
        device->device = device->physical_device.createDevice(device_create_info);
4545
4546 // Queues
        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
4548
4549 // Shaders
4550 // Disable matmul tile sizes early if their performance is low or they are not supported
4551 for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
4552 switch (device->vendor_id) {
4553#ifndef GGML_VULKAN_RUN_TESTS
4554 case VK_VENDOR_ID_AMD:
4555 case VK_VENDOR_ID_INTEL:
4556 device->mul_mat_l[i] = false;
4557 device->mul_mat_m[i] = true;
4558 device->mul_mat_s[i] = true;
4559 device->mul_mat_id_l[i] = false;
4560 device->mul_mat_id_m[i] = true;
4561 device->mul_mat_id_s[i] = true;
4562 break;
4563 case VK_VENDOR_ID_APPLE:
4564 device->mul_mat_l[i] = false;
4565 device->mul_mat_m[i] = true;
4566 device->mul_mat_s[i] = false;
4567 device->mul_mat_id_l[i] = false;
4568 device->mul_mat_id_m[i] = true;
4569 device->mul_mat_id_s[i] = false;
4570 break;
4571#endif
4572 default:
4573 device->mul_mat_l[i] = true;
4574 device->mul_mat_m[i] = true;
4575 device->mul_mat_s[i] = true;
4576 device->mul_mat_id_l[i] = true;
4577 device->mul_mat_id_m[i] = true;
4578 device->mul_mat_id_s[i] = true;
4579 break;
4580 }
4581 }
4582
4583
4584 std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
4585 std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
4586 for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) {
4587 dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
4588 dsl_binding_flags.push_back({});
4589 }
4590
4591 vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
4592
4593 vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
4594 {},
4595 dsl_binding);
4596 descriptor_set_layout_create_info.setPNext(&dslbfci);
4597 device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
4598
4599 ggml_vk_load_shaders(device);
4600
4601 if (!device->single_queue) {
4602 const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
4603 ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
4604 } else {
4605 // TODO: Use pointer or reference to avoid copy
4606 device->transfer_queue.copyFrom(device->compute_queue);
4607 device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
4608 }
4609
4610 device->buffer_type = {
4611 /* .iface = */ ggml_backend_vk_buffer_type_interface,
4612 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
4613 /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
4614 };
4615
4616 device->fence = device->device.createFence({});
4617
4618 device->idx = idx;
4619
4620 device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
4621
4622 device->add_rms_fusion = !device->disable_fusion &&
4623 device->subgroup_arithmetic &&
4624 device->vendor_id != VK_VENDOR_ID_INTEL;
4625 device->partials_binding_alignment =
4626 std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
4627
4628 device->mmvq_mode = 0;
4629 if (getenv("GGML_VK_DISABLE_MMVQ")) {
4630 device->mmvq_mode = -1;
4631 } else if (getenv("GGML_VK_FORCE_MMVQ")) {
4632 device->mmvq_mode = 1;
4633 }
4634
4635 return device;
4636 }
4637
4638 return vk_instance.devices[idx];
4639}
4640
4641static void ggml_vk_print_gpu_info(size_t idx) {
4642 GGML_ASSERT(idx < vk_instance.device_indices.size());
4643 size_t dev_num = vk_instance.device_indices[idx];
4644 VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
4645 GGML_ASSERT(vk_instance_initialized);
4646
4647 std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4648
4649 if (dev_num >= devices.size()) {
4650 std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
4651 throw std::runtime_error("Device not found");
4652 }
4653
4654 vk::PhysicalDevice physical_device = devices[dev_num];
4655 std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
4656
4657 bool fp16_storage = false;
4658 bool fp16_compute = false;
4659 bool coopmat_support = false;
4660 bool coopmat2_support = false;
4661 bool integer_dot_product = false;
4662 bool bfloat16_support = false;
4663
4664 for (auto properties : ext_props) {
4665 if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
4666 fp16_storage = true;
4667 } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
4668 fp16_compute = true;
4669#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
4670 } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
4671 !getenv("GGML_VK_DISABLE_COOPMAT")) {
4672 coopmat_support = true;
4673#endif
4674#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
4675 } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
4676 !getenv("GGML_VK_DISABLE_COOPMAT2")) {
4677 coopmat2_support = true;
4678#endif
4679#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
4680 } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
4681 !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
4682 integer_dot_product = true;
4683#endif
4684#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
4685 } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
4686 !getenv("GGML_VK_DISABLE_BFLOAT16")) {
4687 bfloat16_support = true;
4688#endif
4689 }
4690 }
4691
4692 const vk_device_architecture device_architecture = get_device_architecture(physical_device);
4693
4694 const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
4695 bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
4696
4697 bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
4698
4699 vk::PhysicalDeviceProperties2 props2;
4700 vk::PhysicalDeviceMaintenance3Properties props3;
4701 vk::PhysicalDeviceSubgroupProperties subgroup_props;
4702 vk::PhysicalDeviceDriverProperties driver_props;
4703 vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
4704 props2.pNext = &props3;
4705 props3.pNext = &subgroup_props;
4706 subgroup_props.pNext = &driver_props;
4707
4708 // Pointer to the last chain element
4709 VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
4710
4711 if (integer_dot_product) {
4712 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
4713 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
4714 }
4715
4716 physical_device.getProperties2(&props2);
4717
4718 VkPhysicalDeviceFeatures2 device_features2;
4719 device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
4720 device_features2.pNext = nullptr;
4721
4722 VkPhysicalDeviceVulkan11Features vk11_features;
4723 vk11_features.pNext = nullptr;
4724 vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
4725 device_features2.pNext = &vk11_features;
4726
4727 VkPhysicalDeviceVulkan12Features vk12_features;
4728 vk12_features.pNext = nullptr;
4729 vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
4730 vk11_features.pNext = &vk12_features;
4731
4732 // Pointer to the last chain element
4733 last_struct = (VkBaseOutStructure *)&vk12_features;
4734
4735#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
4736 VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
4737 coopmat_features.pNext = nullptr;
4738 coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
4739 coopmat_features.cooperativeMatrix = VK_FALSE;
4740
4741 if (coopmat_support) {
4742 last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
4743 last_struct = (VkBaseOutStructure *)&coopmat_features;
4744 }
4745#endif
4746
4747 VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
4748 shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
4749 if (integer_dot_product) {
4750 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
4751 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
4752 }
4753
4754#if defined(VK_KHR_shader_bfloat16)
4755 VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
4756 bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
4757 if (bfloat16_support) {
4758 last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
4759 last_struct = (VkBaseOutStructure *)&bfloat16_features;
4760 }
4761#endif
4762
4763 vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
4764
4765 fp16 = fp16 && vk12_features.shaderFloat16;
4766
4767#if defined(VK_KHR_shader_bfloat16)
4768 bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
4769#else
4770 bool bf16 = false;
4771#endif
4772
4773 uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
4774 const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
4775 const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
4776
4777 integer_dot_product = integer_dot_product
4778 && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
4779 && shader_integer_dot_product_features.shaderIntegerDotProduct;
4780
4781 coopmat_support = coopmat_support
4782#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
4783 && coopmat_features.cooperativeMatrix
4784#endif
4785 && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
4786
4787 std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
4788
4789 std::string device_name = props2.properties.deviceName.data();
4790 GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
4791 idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
4792 props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
4793
4794 if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
4795 GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
4796 }
4797}
4798
4799static bool ggml_vk_instance_validation_ext_available();
4800static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
4801static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
4802static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
4803
4804static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
4805DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
4806 return ggml_vk_default_dispatcher_instance;
4807}
4808
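// Summary of the initialization below: set up the dynamic dispatcher, require Vulkan 1.2,
// enable the validation / portability / debug-utils instance extensions when available, and build
// the list of device indices. GGML_VK_VISIBLE_DEVICES overrides device selection; otherwise all
// supported discrete and integrated GPUs are used, and physical devices that expose the same GPU
// through different drivers are de-duplicated by UUID/LUID using a per-vendor driver priority list.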
4809static void ggml_vk_instance_init() {
4810 if (vk_instance_initialized) {
4811 return;
4812 }
4813 VK_LOG_DEBUG("ggml_vk_instance_init()");
4814
4815 // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
4816 ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr);
4817
4818 uint32_t api_version = vk::enumerateInstanceVersion();
4819
4820 if (api_version < VK_API_VERSION_1_2) {
4821 std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
4822 throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required");
4823 }
4824
4825 vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
4826
4827 const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
4828 const bool validation_ext = ggml_vk_instance_validation_ext_available();
4829#ifdef __APPLE__
4830 const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
4831#endif
4832 const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
4833 std::vector<const char*> layers;
4834
4835 if (validation_ext) {
4836 layers.push_back("VK_LAYER_KHRONOS_validation");
4837 }
4838 std::vector<const char*> extensions;
4839 if (validation_ext) {
4840 extensions.push_back("VK_EXT_validation_features");
4841 }
4842#ifdef __APPLE__
4843 if (portability_enumeration_ext) {
4844 extensions.push_back("VK_KHR_portability_enumeration");
4845 }
4846#endif
4847 if (debug_utils_ext) {
4848 extensions.push_back("VK_EXT_debug_utils");
4849 }
4850 vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
4851#ifdef __APPLE__
4852 if (portability_enumeration_ext) {
4853 instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
4854 }
4855#endif
4856
4857 std::vector<vk::ValidationFeatureEnableEXT> features_enable;
4858 vk::ValidationFeaturesEXT validation_features;
4859
4860 if (validation_ext) {
4861 features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
4862 validation_features = {
4863 features_enable,
4864 {},
4865 };
4866 validation_features.setPNext(nullptr);
4867 instance_create_info.setPNext(&validation_features);
4868 GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
4869 }
4870 vk_instance.instance = vk::createInstance(instance_create_info);
4871 vk_instance_initialized = true;
4872
4873 if (debug_utils_ext) {
4874 vk_instance.debug_utils_support = true;
4875 vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
4876 vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
4877 vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
4878 vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
4879 vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
4880 vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
4881 }
4882
4883 vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
4884
4885 // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
4886 VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);
4887
4888 std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4889
4890 // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
4891 char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
4892 if (devices_env != nullptr) {
4893 size_t num_available_devices = devices.size();
4894
4895 std::string devices(devices_env);
4896 std::replace(devices.begin(), devices.end(), ',', ' ');
4897
4898 std::stringstream ss(devices);
4899 size_t tmp;
4900 while (ss >> tmp) {
4901 if(tmp >= num_available_devices) {
4902 std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl;
4903 throw std::runtime_error("Invalid Vulkan device index");
4904 }
4905 vk_instance.device_indices.push_back(tmp);
4906 }
4907 } else {
4908 // If no vulkan devices are found, return early
4909 if (devices.empty()) {
4910 GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
4911 return;
4912 }
4913
4914 // Default to using all dedicated GPUs
4915 for (size_t i = 0; i < devices.size(); i++) {
4916 vk::PhysicalDeviceProperties2 new_props;
4917 vk::PhysicalDeviceDriverProperties new_driver;
4918 vk::PhysicalDeviceIDProperties new_id;
4919 new_props.pNext = &new_driver;
4920 new_driver.pNext = &new_id;
4921 devices[i].getProperties2(&new_props);
4922
4923 if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(vkdev: devices[i])) {
4924 // Check if there are two physical devices corresponding to the same GPU
4925 auto old_device = std::find_if(
4926 vk_instance.device_indices.begin(),
4927 vk_instance.device_indices.end(),
4928 [&devices, &new_id](const size_t k){
4929 vk::PhysicalDeviceProperties2 old_props;
4930 vk::PhysicalDeviceIDProperties old_id;
4931 old_props.pNext = &old_id;
4932 devices[k].getProperties2(&old_props);
4933
4934 bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
4935 equals = equals || (
4936 old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
4937 std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
4938 );
4939
4940 return equals;
4941 }
4942 );
4943 if (old_device == vk_instance.device_indices.end()) {
4944 vk_instance.device_indices.push_back(i);
4945 } else {
4946 // There can be two physical devices corresponding to the same GPU if there are 2 different drivers
4947 // This can cause errors when splitting layers across the devices, so keep only one of them
4948 VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
4949
4950 vk::PhysicalDeviceProperties2 old_props;
4951 vk::PhysicalDeviceDriverProperties old_driver;
4952 old_props.pNext = &old_driver;
4953 devices[*old_device].getProperties2(&old_props);
4954
4955 std::map<vk::DriverId, int> driver_priorities {};
4956 int old_priority = std::numeric_limits<int>::max();
4957 int new_priority = std::numeric_limits<int>::max();
4958
4959 // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver ids
4960 // Smaller number -> higher priority
4961 switch (old_props.properties.vendorID) {
4962 case VK_VENDOR_ID_AMD:
4963 driver_priorities[vk::DriverId::eMesaRadv] = 1;
4964 driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
4965 driver_priorities[vk::DriverId::eAmdProprietary] = 3;
4966 break;
4967 case VK_VENDOR_ID_INTEL:
4968 driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
4969 driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;
4970 break;
4971 case VK_VENDOR_ID_NVIDIA:
4972 driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;
4973#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
4974 driver_priorities[vk::DriverId::eMesaNvk] = 2;
4975#endif
4976 break;
4977 }
4978 driver_priorities[vk::DriverId::eMesaDozen] = 100;
4979
4980 if (driver_priorities.count(old_driver.driverID)) {
4981 old_priority = driver_priorities[old_driver.driverID];
4982 }
4983 if (driver_priorities.count(new_driver.driverID)) {
4984 new_priority = driver_priorities[new_driver.driverID];
4985 }
4986
4987 if (new_priority < old_priority) {
4988 auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
4989 vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
4990 vk_instance.device_indices.push_back(i);
4991
4992 VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName);
4993 }
4994 else {
4995 VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName);
4996 }
4997 }
4998 }
4999 }
5000
5001 // If no GPUs found, fall back to the first non-CPU device.
5002 // If only CPU devices are available, return without devices.
5003 if (vk_instance.device_indices.empty()) {
5004 for (size_t i = 0; i < devices.size(); i++) {
5005 if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) {
5006 vk_instance.device_indices.push_back(i);
5007 break;
5008 }
5009 }
5010 }
5011
5012 if (vk_instance.device_indices.empty()) {
5013 GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
5014 return;
5015 }
5016 }
5017 GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
5018
5019 for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
5020 vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
5021 std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
5022
5023 bool membudget_supported = false;
5024 for (const auto & ext : extensionprops) {
5025 if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) {
5026 membudget_supported = true;
5027 break;
5028 }
5029 }
5030
5031 vk_instance.device_supports_membudget.push_back(membudget_supported);
5032
5033 ggml_vk_print_gpu_info(i);
5034 }
5035}
5036
5037static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
5038 VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
5039 ggml_vk_instance_init();
5040 GGML_ASSERT(idx < vk_instance.device_indices.size());
5041
5042 ctx->name = GGML_VK_NAME + std::to_string(idx);
5043
5044 ctx->device = ggml_vk_get_device(idx);
5045
5046 ctx->semaphore_idx = 0;
5047 ctx->event_idx = 0;
5048
5049 ctx->prealloc_size_x = 0;
5050 ctx->prealloc_size_y = 0;
5051 ctx->prealloc_size_split_k = 0;
5052 ctx->prealloc_size_add_rms_partials = 0;
5053
5054 ctx->fence = ctx->device->device.createFence({});
5055 ctx->almost_ready_fence = ctx->device->device.createFence({});
5056
5057 ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
5058 ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
5059
5060#ifdef GGML_VULKAN_CHECK_RESULTS
5061 const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
5062 vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
5063 const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR");
5064 vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));
5065#endif
5066}
5067
5068static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
5069 VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
5070 switch (type) {
5071 case GGML_TYPE_F32:
5072 case GGML_TYPE_Q4_0:
5073 case GGML_TYPE_Q4_1:
5074 case GGML_TYPE_Q5_0:
5075 case GGML_TYPE_Q5_1:
5076 case GGML_TYPE_Q8_0:
5077 case GGML_TYPE_Q2_K:
5078 case GGML_TYPE_Q3_K:
5079 case GGML_TYPE_Q4_K:
5080 case GGML_TYPE_Q5_K:
5081 case GGML_TYPE_Q6_K:
5082 case GGML_TYPE_IQ1_S:
5083 case GGML_TYPE_IQ1_M:
5084 case GGML_TYPE_IQ2_XXS:
5085 case GGML_TYPE_IQ2_XS:
5086 case GGML_TYPE_IQ2_S:
5087 case GGML_TYPE_IQ3_XXS:
5088 case GGML_TYPE_IQ3_S:
5089 case GGML_TYPE_IQ4_XS:
5090 case GGML_TYPE_IQ4_NL:
5091 case GGML_TYPE_MXFP4:
5092 break;
5093 default:
5094 return nullptr;
5095 }
5096
5097 return ctx->device->pipeline_dequant[type];
5098}
5099
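// Pipeline selection for matrix-matrix multiplication: f32 and bf16 inputs map to dedicated
// pipelines; f16 inputs choose an f16 or f32 accumulator based on the requested precision and
// coopmat accumulator support; Q8_1 activations use the MMQ pipelines; other quantized weights
// fall back to the dequantizing matmul pipelines.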
5100static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
5101 VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")");
5102 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
5103 return ctx->device->pipeline_matmul_f32;
5104 }
5105 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
5106 return ctx->device->pipeline_matmul_f32_f16;
5107 }
5108 if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
5109 return ctx->device->pipeline_matmul_bf16;
5110 }
5111 if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
5112 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5113 return ctx->device->pipeline_matmul_f16_f32.f16acc;
5114 }
5115 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5116 return ctx->device->pipeline_matmul_f16.f16acc;
5117 }
5118 } else {
5119 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5120 return ctx->device->pipeline_matmul_f16_f32.f32acc;
5121 }
5122 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5123 return ctx->device->pipeline_matmul_f16.f32acc;
5124 }
5125 }
5126
5127 // MMQ
5128 if (src1_type == GGML_TYPE_Q8_1) {
5129 vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
5130
5131 if (pipelines->is_empty()) {
5132 return nullptr;
5133 }
5134
5135 return pipelines;
5136 }
5137
5138 if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
5139 return nullptr;
5140 }
5141
5142 switch (src0_type) {
5143 case GGML_TYPE_Q4_0:
5144 case GGML_TYPE_Q4_1:
5145 case GGML_TYPE_Q5_0:
5146 case GGML_TYPE_Q5_1:
5147 case GGML_TYPE_Q8_0:
5148 case GGML_TYPE_Q2_K:
5149 case GGML_TYPE_Q3_K:
5150 case GGML_TYPE_Q4_K:
5151 case GGML_TYPE_Q5_K:
5152 case GGML_TYPE_Q6_K:
5153 case GGML_TYPE_IQ1_S:
5154 case GGML_TYPE_IQ1_M:
5155 case GGML_TYPE_IQ2_XXS:
5156 case GGML_TYPE_IQ2_XS:
5157 case GGML_TYPE_IQ2_S:
5158 case GGML_TYPE_IQ3_XXS:
5159 case GGML_TYPE_IQ3_S:
5160 case GGML_TYPE_IQ4_XS:
5161 case GGML_TYPE_IQ4_NL:
5162 case GGML_TYPE_MXFP4:
5163 break;
5164 default:
5165 return nullptr;
5166 }
5167
5168 if (ctx->device->coopmat2) {
5169 assert(src1_type == GGML_TYPE_F16);
5170 return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;
5171 }
5172 if (ctx->device->coopmat_support) {
5173 return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
5174 }
5175 return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
5176}
5177
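// Pipeline selection for dequantize + matrix-vector multiplication. The heuristic below picks a
// larger workgroup size on Intel and on post-Turing NVIDIA GPUs when m is small and k is large,
// which is intended to keep more compute units busy; Q8_1 activations use dedicated pipelines.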
5178static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {
5179 VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
5180 GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1);
5181 GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
5182
5183 if (b_type == GGML_TYPE_Q8_1) {
5184 switch (a_type) {
5185 case GGML_TYPE_Q4_0:
5186 case GGML_TYPE_Q4_1:
5187 case GGML_TYPE_Q5_0:
5188 case GGML_TYPE_Q5_1:
5189 case GGML_TYPE_Q8_0:
5190 break;
5191 default:
5192 return nullptr;
5193 }
5194 }
5195
5196 switch (a_type) {
5197 case GGML_TYPE_F32:
5198 case GGML_TYPE_F16:
5199 case GGML_TYPE_BF16:
5200 case GGML_TYPE_Q4_0:
5201 case GGML_TYPE_Q4_1:
5202 case GGML_TYPE_Q5_0:
5203 case GGML_TYPE_Q5_1:
5204 case GGML_TYPE_Q8_0:
5205 case GGML_TYPE_Q2_K:
5206 case GGML_TYPE_Q3_K:
5207 case GGML_TYPE_Q4_K:
5208 case GGML_TYPE_Q5_K:
5209 case GGML_TYPE_Q6_K:
5210 case GGML_TYPE_IQ1_S:
5211 case GGML_TYPE_IQ1_M:
5212 case GGML_TYPE_IQ2_XXS:
5213 case GGML_TYPE_IQ2_XS:
5214 case GGML_TYPE_IQ2_S:
5215 case GGML_TYPE_IQ3_XXS:
5216 case GGML_TYPE_IQ3_S:
5217 case GGML_TYPE_IQ4_XS:
5218 case GGML_TYPE_IQ4_NL:
5219 case GGML_TYPE_MXFP4:
5220 break;
5221 default:
5222 return nullptr;
5223 }
5224
5225 // heuristic to choose workgroup size
5226 uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
5227 if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
5228 // Prefer larger workgroups when M is small, to spread the work out more
5229 // and keep more SMs busy.
5230 // q6_k seems to prefer small workgroup size even for "medium" values of M.
5231 if (a_type == GGML_TYPE_Q6_K) {
5232 if (m < 4096 && k >= 1024) {
5233 dmmv_wg = DMMV_WG_SIZE_LARGE;
5234 }
5235 } else {
5236 if (m <= 8192 && k >= 1024) {
5237 dmmv_wg = DMMV_WG_SIZE_LARGE;
5238 }
5239 }
5240 }
5241
5242 if (b_type == GGML_TYPE_Q8_1) {
5243 if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
5244 dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
5245 }
5246 return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1];
5247 }
5248
5249 return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];
5250}
5251
5252static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
5253 VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
5254 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
5255 return ctx->device->pipeline_matmul_id_f32;
5256 }
5257 if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
5258 return ctx->device->pipeline_matmul_id_bf16;
5259 }
5260 if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
5261 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5262 return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
5263 }
5264 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5265 return ctx->device->pipeline_matmul_id_f16.f16acc;
5266 }
5267 } else {
5268 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5269 return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
5270 }
5271 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5272 return ctx->device->pipeline_matmul_id_f16.f32acc;
5273 }
5274 }
5275
5276 // MMQ
5277 if (src1_type == GGML_TYPE_Q8_1) {
5278 vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
5279
5280 if (pipelines->is_empty()) {
5281 return nullptr;
5282 }
5283
5284 return pipelines;
5285 }
5286
5287 GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
5288
5289 switch (src0_type) {
5290 case GGML_TYPE_Q4_0:
5291 case GGML_TYPE_Q4_1:
5292 case GGML_TYPE_Q5_0:
5293 case GGML_TYPE_Q5_1:
5294 case GGML_TYPE_Q8_0:
5295 case GGML_TYPE_Q2_K:
5296 case GGML_TYPE_Q3_K:
5297 case GGML_TYPE_Q4_K:
5298 case GGML_TYPE_Q5_K:
5299 case GGML_TYPE_Q6_K:
5300 case GGML_TYPE_IQ1_S:
5301 case GGML_TYPE_IQ1_M:
5302 case GGML_TYPE_IQ2_XXS:
5303 case GGML_TYPE_IQ2_XS:
5304 case GGML_TYPE_IQ2_S:
5305 case GGML_TYPE_IQ3_XXS:
5306 case GGML_TYPE_IQ3_S:
5307 case GGML_TYPE_IQ4_XS:
5308 case GGML_TYPE_IQ4_NL:
5309 case GGML_TYPE_MXFP4:
5310 break;
5311 default:
5312 return nullptr;
5313 }
5314
5315 vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
5316 // XXX TODO 'prec' is not actually allowed in mul_mat_id.
5317 bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
5318 bool support_fp16acc = !mmp.f16acc->is_empty();
5319 bool support_fp32acc = !mmp.f32acc->is_empty();
5320
5321 if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
5322 return mmp.f16acc;
5323 } else {
5324 GGML_ASSERT(support_fp32acc);
5325 return mmp.f32acc;
5326 }
5327}
5328
5329static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
5330 VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
5331 GGML_ASSERT(b_type == GGML_TYPE_F32);
5332
5333 switch (a_type) {
5334 case GGML_TYPE_F32:
5335 case GGML_TYPE_F16:
5336 case GGML_TYPE_BF16:
5337 case GGML_TYPE_Q4_0:
5338 case GGML_TYPE_Q4_1:
5339 case GGML_TYPE_Q5_0:
5340 case GGML_TYPE_Q5_1:
5341 case GGML_TYPE_Q8_0:
5342 case GGML_TYPE_Q2_K:
5343 case GGML_TYPE_Q3_K:
5344 case GGML_TYPE_Q4_K:
5345 case GGML_TYPE_Q5_K:
5346 case GGML_TYPE_Q6_K:
5347 case GGML_TYPE_IQ1_S:
5348 case GGML_TYPE_IQ1_M:
5349 case GGML_TYPE_IQ2_XXS:
5350 case GGML_TYPE_IQ2_XS:
5351 case GGML_TYPE_IQ2_S:
5352 case GGML_TYPE_IQ3_XXS:
5353 case GGML_TYPE_IQ3_S:
5354 case GGML_TYPE_IQ4_XS:
5355 case GGML_TYPE_IQ4_NL:
5356 case GGML_TYPE_MXFP4:
5357 break;
5358 default:
5359 return nullptr;
5360 }
5361
5362 return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
5363}
5364
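// Allocates host-pinned memory (host-visible, preferably cached) and registers the allocation in
// the device's pinned_memory list so later transfers can recognize it and avoid staging copies.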
5365static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
5366 VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
5367 vk_buffer buf = ggml_vk_create_buffer(device, size,
5368 {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
5369 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
5370
5371 if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
5372 fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
5373 size/1024.0/1024.0);
5374 device->device.freeMemory(buf->device_memory);
5375 device->device.destroyBuffer(buf->buffer);
5376 return nullptr;
5377 }
5378
5379 std::lock_guard<std::recursive_mutex> guard(device->mutex);
5380 device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
5381
5382 return buf->ptr;
5383}
5384
5385static void ggml_vk_host_free(vk_device& device, void* ptr) {
5386 if (ptr == nullptr) {
5387 return;
5388 }
5389 VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
5390 std::lock_guard<std::recursive_mutex> guard(device->mutex);
5391
5392 vk_buffer buf;
5393 size_t index;
5394 for (size_t i = 0; i < device->pinned_memory.size(); i++) {
5395 const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
5396 const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
5397 if (ptr >= addr && ptr < endr) {
5398 buf = std::get<2>(device->pinned_memory[i]);
5399 index = i;
5400 break;
5401 }
5402 }
5403 if (buf == nullptr) {
5404 fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
5405 return;
5406 }
5407
5408 ggml_vk_destroy_buffer(buf);
5409
5410 device->pinned_memory.erase(device->pinned_memory.begin() + index);
5411}
5412
5413static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
5414 std::lock_guard<std::recursive_mutex> guard(device->mutex);
5415 buf = nullptr;
5416 buf_offset = 0;
5417 for (size_t i = 0; i < device->pinned_memory.size(); i++) {
5418 const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
5419 const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
5420 if (ptr >= addr && ptr < endr) {
5421 buf = std::get<2>(device->pinned_memory[i]);
5422 buf_offset = ((const uint8_t *)ptr) - addr;
5423 break;
5424 }
5425 }
5426}
5427
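// Resolves the buffer and offset backing a tensor: on UMA devices the data may live in pinned
// host memory, otherwise the device buffer from the tensor's buffer context is used. Misaligned
// offsets are rounded down to the storage buffer offset alignment and the size is grown to
// compensate; the shader must handle the residual misalignment when allow_misalign is set.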
5428static vk_subbuffer ggml_vk_tensor_subbuffer(
5429 const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
5430
5431 vk_buffer buffer = nullptr;
5432 size_t offset = 0;
5433 if (ctx->device->uma) {
5434 ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
5435 }
5436 if (!buffer) {
5437 auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
5438 buffer = buf_ctx->dev_buffer;
5439 offset = vk_tensor_offset(tensor) + tensor->view_offs;
5440 }
5441 GGML_ASSERT(buffer != nullptr);
5442
5443 size_t size = ggml_nbytes(tensor);
5444
5445 size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
5446 // The shader must support misaligned offsets when indexing into the buffer
5447 GGML_ASSERT(allow_misalign || misalign_bytes == 0);
5448 offset &= ~misalign_bytes;
5449 size += misalign_bytes;
5450
5451 return vk_subbuffer{buffer, offset, size};
5452}
5453
5454static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
5455 vk_submission s;
5456 s.buffer = ggml_vk_create_cmd_buffer(device, p);
5457 if (one_time) {
5458 s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
5459 } else {
5460 s.buffer.begin({ vk::CommandBufferUsageFlags{} });
5461 }
5462
5463 return s;
5464}
5465
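// Helpers that derive the push constant byte size and data pointer from either a plain
// struct, a std::vector or a std::array of push constant values.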
5466template <typename T> size_t push_constant_size(const T &t) {
5467 static_assert(std::is_class<T>::value, "T must be a struct/class");
5468 GGML_UNUSED(t);
5469 return sizeof(T);
5470}
5471template <typename T> size_t push_constant_size(const std::vector<T> &t) {
5472 GGML_UNUSED(t);
5473 return sizeof(T) * t.size();
5474}
5475template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
5476 GGML_UNUSED(t);
5477 return sizeof(T) * N;
5478}
5479
5480template <typename T> const T *push_constant_data(const T &t) {
5481 static_assert(std::is_class<T>::value, "T must be a struct/class");
5482 return &t;
5483}
5484template <typename T> const T *push_constant_data(const std::vector<T> &t) {
5485 return t.data();
5486}
5487template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
5488 return t.data();
5489}
5490
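// Records a compute dispatch: writes the storage-buffer descriptors into the next descriptor set,
// pushes the push constants, binds the pipeline and descriptor set, and dispatches
// CEIL_DIV(elements, wg_denoms) workgroups in each dimension.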
5491template <typename T>
5492static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
5493 const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
5494 const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
5495 const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
5496 VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
5497 for (auto& buffer : descriptor_buffer_infos) {
5498 std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
5499 }
5500 std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
5501 GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
5502 GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
5503 GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
5504
5505 vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
5506 vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
5507 ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
5508
5509 subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
5510 subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
5511 subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
5512 pipeline->layout,
5513 0,
5514 { descriptor_set },
5515 {});
5516 subctx->s->buffer.dispatch(wg0, wg1, wg2);
5517}
5518
5519static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
5520 s.buffer.end();
5521
5522 s.wait_semaphores = std::move(wait_semaphores);
5523 s.signal_semaphores = std::move(signal_semaphores);
5524}
5525
5526static void ggml_vk_ctx_end(vk_context& ctx) {
5527 VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
5528 if (ctx->s == nullptr) {
5529 return;
5530 }
5531
5532 ctx->s->buffer.end();
5533 ctx->s = nullptr;
5534}
5535
5536static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
5537 VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
5538 if (subctx->s != nullptr) {
5539 ggml_vk_ctx_end(subctx);
5540 }
5541
5542 subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });
5543 subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
5544}
5545
5546static size_t ggml_vk_align_size(size_t width, size_t align) {
5547 VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
5548 return CEIL_DIV(width, align) * align;
5549}
5550
5551static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {
5552 if (memcpys == nullptr) {
5553 memcpy(dst, src, size);
5554 } else {
5555 memcpys->emplace_back(dst, src, size);
5556 }
5557}
5558
5559static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {
5560 if (memsets == nullptr) {
5561 memset(dst, val, size);
5562 } else {
5563 memsets->emplace_back(dst, val, size);
5564 }
5565}
5566
5567static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
5568 if (device->sync_staging == nullptr || device->sync_staging->size < size) {
5569 VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
5570 ggml_vk_destroy_buffer(device->sync_staging);
5571 device->sync_staging = ggml_vk_create_buffer_check(device, size,
5572 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
5573 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
5574 }
5575}
5576
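// Asynchronous upload of a non-contiguous tensor into a device-local buffer. If the source is
// pinned host memory it is used directly as the staging area and copied as the largest possible
// contiguous slices; otherwise the data goes through the device's sync staging buffer and the
// host-side copies are deferred into subctx->in_memcpys.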
5577static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
5578 VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
5579 GGML_ASSERT(!ggml_is_contiguous(tensor));
5580 // Buffer is already mapped
5581 if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
5582 std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
5583 GGML_ABORT("fatal error");
5584 }
5585 // Check if src is pinned memory
5586 vk_buffer buf = nullptr;
5587 size_t buf_offset = 0;
5588 ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
5589
5590 const uint64_t ne0 = tensor->ne[0];
5591 const uint64_t ne1 = tensor->ne[1];
5592 const uint64_t ne2 = tensor->ne[2];
5593 const uint64_t ne3 = tensor->ne[3];
5594 const uint64_t nb0 = tensor->nb[0];
5595 const uint64_t nb1 = tensor->nb[1];
5596 const uint64_t nb2 = tensor->nb[2];
5597 const uint64_t nb3 = tensor->nb[3];
5598 const ggml_type type = tensor->type;
5599 const uint64_t ts = ggml_type_size(type);
5600 const uint64_t bs = ggml_blck_size(type);
5601
5602 const uint64_t dstnb0 = ts;
5603 const uint64_t dstnb1 = dstnb0*(ne0/bs);
5604 const uint64_t dstnb2 = dstnb1*ne1;
5605 const uint64_t dstnb3 = dstnb2*ne2;
5606
5607 const uint64_t ne = ggml_nelements(tensor);
5608
5609 if (buf != nullptr) {
5610 // Memory is pinned, use as staging buffer
5611 std::vector<vk::BufferCopy> slices;
5612
5613 for (uint64_t i3 = 0; i3 < ne3; i3++) {
5614 for (uint64_t i2 = 0; i2 < ne2; i2++) {
5615 // Find longest contiguous slice
5616 if (ne1*nb1 == dstnb2) {
5617 slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
5618 } else {
5619 for (uint64_t i1 = 0; i1 < ne1; i1++) {
5620 if (ne0*nb0/bs == dstnb1) {
5621 slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
5622 } else {
5623 const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
5624 const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
5625 for (uint64_t i0 = 0; i0 < ne0; i0++) {
5626 slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
5627 }
5628 }
5629 }
5630 }
5631 }
5632 }
5633
5634 ggml_vk_sync_buffers(ctx, subctx);
5635 subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
5636 return;
5637 }
5638
5639 if (!sync_staging) {
5640 GGML_ABORT("Asynchronous write to non-pinned memory not supported");
5641 }
5642
5643 // Staging buffer required
5644 vk_buffer& staging = ctx->device->sync_staging;
5645 const uint64_t copy_size = ts*ne/bs;
5646 ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
5647 VkBufferCopy buf_copy{ 0, offset, copy_size };
5648
5649 ggml_vk_sync_buffers(ctx, subctx);
5650 vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
5651
5652 for (uint64_t i3 = 0; i3 < ne3; i3++) {
5653 for (uint64_t i2 = 0; i2 < ne2; i2++) {
5654 // Find longest contiguous slice
5655 if (ne1*nb1 == dstnb2) {
5656 deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
5657 } else {
5658 for (uint64_t i1 = 0; i1 < ne1; i1++) {
5659 if (ne0*nb0/bs == dstnb1) {
5660 deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
5661 } else {
5662 const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
5663 const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
5664 for (uint64_t i0 = 0; i0 < ne0; i0++) {
5665 deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
5666 }
5667 }
5668 }
5669 }
5670 }
5671 }
5672}
5673
5674static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
5675 VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
5676 // Buffer is already mapped
5677 if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
5678 std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
5679 GGML_ABORT("fatal error");
5680 }
5681 // Check if src is pinned memory
5682 vk_buffer buf = nullptr;
5683 size_t buf_offset = 0;
5684 ggml_vk_host_get(dst->device, src, buf, buf_offset);
5685
5686 if (buf != nullptr) {
5687 // Memory is pinned, use as staging buffer
5688 std::vector<vk::BufferCopy> slices(1);
5689 if (width == spitch) {
5690 // Only do single write if stride is equal
5691 slices[0].srcOffset = buf_offset;
5692 slices[0].dstOffset = offset;
5693 slices[0].size = width * height;
5694 } else {
5695 slices.resize(height);
5696 for (size_t i = 0; i < height; i++) {
5697 slices[i].srcOffset = buf_offset + i * spitch;
5698 slices[i].dstOffset = offset + i * width;
5699 slices[i].size = width;
5700 }
5701 }
5702
5703 ggml_vk_sync_buffers(nullptr, subctx);
5704 subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
5705 return;
5706 }
5707 VK_LOG_DEBUG("STAGING");
5708
5709 if (!sync_staging) {
5710 GGML_ABORT("Asynchronous write to non-pinned memory not supported");
5711 }
5712
5713 // Staging buffer required
5714 const size_t copy_size = width*height;
5715 ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
5716
5717 vk_buffer& staging_buffer = dst->device->sync_staging;
5718
5719 VkBufferCopy buf_copy = {
5720 0,
5721 offset,
5722 copy_size};
5723
5724 ggml_vk_sync_buffers(nullptr, subctx);
5725 vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
5726
5727 if (width == spitch) {
5728 deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
5729 } else {
5730 for (size_t i = 0; i < height; i++) {
5731 deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
5732 }
5733 }
5734}
5735
5736static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
5737 VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
5738 return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
5739}
5740
5741static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
5742 VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
5743 // Buffer is already mapped
5744 if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
5745 GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
5746
5747 for (size_t i = 0; i < height; i++) {
5748 memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
5749 }
5750 } else {
5751 std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
5752
5753 vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
5754 ggml_vk_ctx_begin(dst->device, subctx);
5755 ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
5756 ggml_vk_ctx_end(subctx);
5757
5758 for (auto& cpy : subctx->in_memcpys) {
5759 memcpy(cpy.dst, cpy.src, cpy.n);
5760 }
5761
5762 for (auto& mset : subctx->memsets) {
5763 memset(mset.dst, mset.val, mset.n);
5764 }
5765
5766 ggml_vk_submit(subctx, dst->device->fence);
5767 VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
5768 dst->device->device.resetFences({ dst->device->fence });
5769 ggml_vk_queue_command_pools_cleanup(dst->device);
5770 }
5771}
5772
5773static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
5774 VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
5775 ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
5776}
5777
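// Asynchronous 2D read from a device buffer. Pinned destinations are copied into directly;
// otherwise the data is copied into the sync staging buffer and a deferred memcpy into dst is
// queued in subctx->out_memcpys, to be executed once the transfer has completed.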
5778static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
5779 VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
5780 GGML_ASSERT(width > 0);
5781 GGML_ASSERT(height > 0);
5782 GGML_ASSERT(src != nullptr);
5783
5784 // TODO: staging_offset is not used
5785
5786 // Check if dst is pinned memory
5787 vk_buffer buf = nullptr;
5788 size_t buf_offset = 0;
5789 ggml_vk_host_get(src->device, dst, buf, buf_offset);
5790
5791 std::vector<vk::BufferCopy> slices(1);
5792 if (width == spitch && width == dpitch) {
5793 // Only do single write if stride is equal
5794 slices[0].srcOffset = offset;
5795 slices[0].dstOffset = buf_offset;
5796 slices[0].size = width * height;
5797 } else {
5798 slices.resize(height);
5799 for (size_t i = 0; i < height; i++) {
5800 slices[i].srcOffset = offset + i * spitch;
5801 slices[i].dstOffset = buf_offset + i * dpitch;
5802 slices[i].size = width;
5803 }
5804 }
5805
5806 if (buf != nullptr) {
5807 // Memory is pinned, use as staging buffer
5808 ggml_vk_sync_buffers(nullptr, subctx);
5809 subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
5810
5811 return;
5812 }
5813 VK_LOG_DEBUG("STAGING");
5814
5815 if (!sync_staging) {
5816 GGML_ABORT("Asynchronous read from non-pinned memory not supported");
5817 }
5818
5819 // Fall back to staging buffer
5820 const size_t copy_size = dpitch * height;
5821 ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
5822
5823 vk_buffer& staging_buffer = src->device->sync_staging;
5824
5825 ggml_vk_sync_buffers(nullptr, subctx);
5826 subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
5827
5828 deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
5829}
5830
5831static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
5832 return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
5833}
5834
5835static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
5836 VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
5837
5838 // If the device is not a UMA device, the memory is still host-accessible through ReBAR. Writing
5839 // through PCIe is sufficiently fast, but reading back over PCIe is slower than going through
5840 // the hardware device-to-host copy path.
5841 if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
5842 GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
5843
5844 memcpy(dst, (uint8_t *) src->ptr + offset, size);
5845 } else {
5846 std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
5847
5848 vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
5849 ggml_vk_ctx_begin(src->device, subctx);
5850 ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
5851 ggml_vk_ctx_end(subctx);
5852
5853 ggml_vk_submit(subctx, src->device->fence);
5854 VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
5855 src->device->device.resetFences({ src->device->fence });
5856 ggml_vk_queue_command_pools_cleanup(src->device);
5857
5858 for (auto& cpy : subctx->out_memcpys) {
5859 memcpy(cpy.dst, cpy.src, cpy.n);
5860 }
5861 }
5862}
5863
5864static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
5865 VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
5866 // Make sure both buffers are on same device
5867 GGML_ASSERT(src->device == dst->device);
5868
5869 VkBufferCopy bc{ src_offset, dst_offset, size };
5870
5871 vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
5872}
5873
5874static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
5875 if (src->device == dst->device) {
5876 std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
5877 VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
5878 // Copy within the device
5879 vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
5880 ggml_vk_ctx_begin(src->device, subctx);
5881 ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
5882 ggml_vk_ctx_end(subctx);
5883 ggml_vk_submit(subctx, src->device->fence);
5884 VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
5885 src->device->device.resetFences({ src->device->fence });
5886 ggml_vk_queue_command_pools_cleanup(src->device);
5887 } else {
5888 VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
5889 // Copy device to device
5890 ggml_vk_ensure_sync_staging_buffer(src->device, size);
5891
5892 // Copy to src staging buffer
5893 ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
5894 // Copy to dst buffer
5895 ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
5896 }
5897}
5898
5899static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
5900 VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
5901
5902 if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
5903 dst->device->uma) {
5904 deferred_memset(dst: (uint8_t*)dst->ptr + offset, val: c, size, memsets: &ctx->memsets);
5905 return;
5906 }
5907
5908 // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
5909 ctx->s->buffer.fillBuffer(dstBuffer: dst->buffer, dstOffset: offset, size, data: c);
5910}
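// Note (assumption about the host paths above and below): memset/deferred_memset write the low byte
// of c to every byte, while fillBuffer repeats the full 32-bit word c, so the two paths only agree
// when all four bytes of c are equal (e.g. c == 0).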
5911
5912static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
5913 VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
5914
5915 if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
5916 dst->device->uma) {
5917 memset(s: (uint8_t*)dst->ptr + offset, c: c, n: size);
5918 return;
5919 }
5920
5921 std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
5922 vk_context subctx = ggml_vk_create_temporary_context(p&: dst->device->transfer_queue.cmd_pool);
5923 ggml_vk_ctx_begin(device&: dst->device, subctx);
5924 subctx->s->buffer.fillBuffer(dstBuffer: dst->buffer, dstOffset: offset, size, data: c);
5925 ggml_vk_ctx_end(ctx&: subctx);
5926
5927 ggml_vk_submit(ctx&: subctx, fence: dst->device->fence);
5928 VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
5929 dst->device->device.resetFences(fences: { dst->device->fence });
5930 ggml_vk_queue_command_pools_cleanup(device&: dst->device);
5931}
5932
5933static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
5934 VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
5935
5936 if (disable_split_k) {
5937 return 1;
5938 }
5939
5940 uint32_t split_k = 1;
5941 if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
5942 // If k is 'large' and the SMs will fill less than halfway, use split_k.
5943 uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
5944 uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
5945
5946 if (k >= 2048) {
5947 if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
5948 split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
5949 } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
5950 split_k = 3;
5951 }
5952 // Cap the split at 8x. Beyond that, the extra reduction work is mostly overhead unless k is huge.
5953 split_k = std::min(a: split_k, b: 8u);
5954
5955 // ggml_vk_matmul will align the splits to be a multiple of 256.
5956 // If this rounded up size would cause the last split to be empty,
5957 // then reduce the split count.
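// Worked example (hypothetical values): k = 2304, split_k = 3 gives k_split = 768 (already a
// multiple of 256) and 768 * 2 = 1536 < 2304, so the split is kept. k = 2100, split_k = 4 gives
// k_split = 525 rounded up to 768, but 768 * 3 = 2304 >= 2100 (the last split would be empty),
// so split_k drops to 3, where 768 * 2 = 1536 < 2100.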
5958 while (true) {
5959 if (split_k == 1) {
5960 break;
5961 }
5962 uint32_t k_split = CEIL_DIV(k, split_k);
5963 k_split = ROUNDUP_POW2(k_split, 256);
5964 if (k_split * (split_k - 1) < k) {
5965 break;
5966 }
5967 split_k--;
5968 }
5969 }
5970 }
5971
5972 return split_k;
5973}
5974
5975static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
5976 VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
5977
5978 if (ctx->device->coopmat2) {
5979 const uint32_t shader_core_count = ctx->device->shader_core_count;
5980 const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
5981 const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
5982
5983 // Use large shader when the N dimension is greater than the medium shader's tile size
5984 uint32_t crossover_large = mmp->m->wg_denoms[1];
5985
5986 // Prefer large over medium if either:
5987 // - medium or large tiles would overfill the GPU
5988 // - large tiles with split_k==3 fit in the GPU and medium tiles with split_k==2 do not
5989 // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
5990 bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
5991 // split_k==3 with large tiles likely better than medium tiles with no split_k.
5992 (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
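// Example (hypothetical counts): shader_core_count = 60, tiles_l = 18, tiles_m = 40: neither tile
// count overfills the GPU, but tiles_l (18) <= 60/3 = 20 while tiles_m (40) > 60/2 = 30, so
// prefer_large is true.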
5993
5994 if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
5995 return aligned ? mmp->a_l : mmp->l;
5996 }
5997 // Use medium shader when the N dimension is greater than the small shader's tile size
5998 uint32_t crossover_medium = mmp->s->wg_denoms[1];
5999 if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
6000 return aligned ? mmp->a_m : mmp->m;
6001 }
6002 return aligned ? mmp->a_s : mmp->s;
6003 }
6004
6005 if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
6006 return aligned ? mmp->a_s : mmp->s;
6007 }
6008 if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
6009 return aligned ? mmp->a_m : mmp->m;
6010 }
6011 return aligned ? mmp->a_l : mmp->l;
6012
6013 GGML_UNUSED(src1_type);
6014}
6015
6016static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
6017 VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
6018 return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, aligned: true, src0_type, src1_type)->align;
6019}
6020
6021static void ggml_vk_matmul(
6022 ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
6023 vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
6024 uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
6025 uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
6026 uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
6027 uint32_t padded_n) {
6028 VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
6029 if (split_k == 1) {
6030 const vk_mat_mat_push_constants pc = { .M: m, .N: n, .K: k, .stride_a: stride_a, .stride_b: stride_b, .stride_d: stride_d, .batch_stride_a: batch_stride_a, .batch_stride_b: batch_stride_b, .batch_stride_d: batch_stride_d, .k_split: k, .ne02: ne02, .ne12: ne12, .broadcast2: broadcast2, .broadcast3: broadcast3, .padded_N: padded_n };
6031 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, descriptor_buffer_infos: { a, b, d }, push_constants: pc, elements: { m, n, batch });
6032 return;
6033 }
6034
6035 if (ctx->prealloc_split_k_need_sync) {
6036 ggml_vk_sync_buffers(ctx, subctx);
6037 }
6038
6039 GGML_ASSERT(batch_stride_d == m * n);
6040
6041 // Round the split size up to a multiple of 256 (k-quant alignment)
6042 uint32_t k_split = CEIL_DIV(k, split_k);
6043 k_split = ROUNDUP_POW2(k_split, 256);
6044
6045 const vk_mat_mat_push_constants pc1 = { .M: m, .N: n, .K: k, .stride_a: stride_a, .stride_b: stride_b, .stride_d: stride_d, .batch_stride_a: batch_stride_a, .batch_stride_b: batch_stride_b, .batch_stride_d: batch_stride_d, .k_split: k_split, .ne02: ne02, .ne12: ne12, .broadcast2: broadcast2, .broadcast3: broadcast3, .padded_N: padded_n };
6046 // Make sure enough workgroups get assigned for split k to work
6047 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, descriptor_buffer_infos: { a, b, split_k_buffer }, push_constants: pc1, elements: { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
6048 ggml_vk_sync_buffers(ctx, subctx);
6049 const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
6050 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline&: ctx->device->pipeline_matmul_split_k_reduce, descriptor_buffer_infos: { split_k_buffer, d }, push_constants: pc2, elements: { m * n * batch, 1, 1 });
6051 ctx->prealloc_split_k_need_sync = true;
6052}
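// Split-k sketch (hypothetical numbers, assuming the usual CEIL_DIV(elements, wg_denoms) workgroup
// count): with m = n = 512, k = 8192, batch = 1, split_k = 4 and 128x128 tiles, the first dispatch
// above launches 4 m-tiles * 4 n-tiles * 4 k-slices = 64 workgroups, each accumulating over a
// 2048-wide slice of k into its own partial D in split_k_buffer; the reduce dispatch then sums the
// 4 partial m*n matrices into d.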
6053
6054static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
6055 VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
6056
6057 if (ctx->device->coopmat2) {
6058 // Use large shader when the N dimension is greater than the medium shader's tile size
6059 uint32_t crossover_large = mmp->m->wg_denoms[1];
6060 if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
6061 return aligned ? mmp->a_l : mmp->l;
6062 }
6063 // Use medium shader when the N dimension is greater than the small shader's tile size
6064 uint32_t crossover_medium = mmp->s->wg_denoms[1];
6065 if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
6066 return aligned ? mmp->a_m : mmp->m;
6067 }
6068 return aligned ? mmp->a_s : mmp->s;
6069 }
6070
6071 if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
6072 return aligned ? mmp->a_s : mmp->s;
6073 }
6074 if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
6075 return aligned ? mmp->a_m : mmp->m;
6076 }
6077 return aligned ? mmp->a_l : mmp->l;
6078}
6079
6080static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
6081 VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
6082 return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, aligned: true, src0_type)->align;
6083}
6084
6085static void ggml_vk_matmul_id(
6086 ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
6087 vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
6088 uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
6089 uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
6090 uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
6091 uint32_t padded_n) {
6092 VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
6093 "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
6094 "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
6095 "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
6096 const vk_mat_mat_id_push_constants pc = { .M: m, .N: n, .K: k, .stride_a: stride_a, .stride_b: stride_b, .stride_d: stride_d, .batch_stride_a: batch_stride_a, .batch_stride_b: batch_stride_b, .batch_stride_d: batch_stride_d,
6097 .nei0: nei0, .nei1: nei1, .nbi1: nbi1, .ne11: ne11, .padded_N: padded_n };
6098 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, descriptor_buffer_infos: { a, b, d, ids }, push_constants: pc, elements: { m, nei1, n_as });
6099}
6100
6101static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
6102 return
6103 tensor->nb[0] == ggml_type_size(type: tensor->type) &&
6104 tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(type: tensor->type) &&
6105 (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
6106}
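// Note on the check above: rows (dims 0 and 1) must be densely packed; the stride of dim 2 is
// unconstrained, and dim 3 only has to be consistent with dim 2 when ne[3] > 1. This is what lets
// dim-2-strided views still take the "contiguous" paths.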
6107
6108static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
6109
6110 // Choose "contiguous copy" shader if src/dst are contiguous
6111 bool contig = ggml_is_contiguous(tensor: src) && (!dst || ggml_is_contiguous(tensor: dst));
6112
6113 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
6114 if (contig) {
6115 return ctx->device->pipeline_contig_cpy_f32_f32;
6116 } else {
6117 return ctx->device->pipeline_cpy_f32_f32;
6118 }
6119 }
6120 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
6121 if (contig) {
6122 return ctx->device->pipeline_contig_cpy_f32_f16;
6123 } else {
6124 return ctx->device->pipeline_cpy_f32_f16;
6125 }
6126 }
6127 if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
6128 if (contig) {
6129 return ctx->device->pipeline_contig_cpy_f16_f16;
6130 } else {
6131 return ctx->device->pipeline_cpy_f16_f16;
6132 }
6133 }
6134 if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
6135 if (contig) {
6136 return ctx->device->pipeline_contig_cpy_f16_f32;
6137 } else {
6138 return ctx->device->pipeline_cpy_f16_f32;
6139 }
6140 }
6141 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
6142 if (contig) {
6143 return ctx->device->pipeline_contig_cpy_f32_bf16;
6144 } else {
6145 return ctx->device->pipeline_cpy_f32_bf16;
6146 }
6147 }
6148 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
6149 if (contig) {
6150 return ctx->device->pipeline_contig_cpy_f32_i32;
6151 } else {
6152 return ctx->device->pipeline_cpy_f32_i32;
6153 }
6154 }
6155 if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) {
6156 if (contig) {
6157 return ctx->device->pipeline_contig_cpy_i32_f32;
6158 } else {
6159 return ctx->device->pipeline_cpy_i32_f32;
6160 }
6161 }
6162 if (src->type == GGML_TYPE_F32) {
6163 switch (to) {
6164 case GGML_TYPE_Q4_0:
6165 case GGML_TYPE_Q4_1:
6166 case GGML_TYPE_Q5_0:
6167 case GGML_TYPE_Q5_1:
6168 case GGML_TYPE_Q8_0:
6169 case GGML_TYPE_IQ4_NL:
6170 return ctx->device->pipeline_cpy_f32_quant[to];
6171 default:
6172 break;
6173 }
6174 }
6175
6176 if (to == GGML_TYPE_F32) {
6177 switch (src->type) {
6178 case GGML_TYPE_Q4_0:
6179 case GGML_TYPE_Q4_1:
6180 case GGML_TYPE_Q5_0:
6181 case GGML_TYPE_Q5_1:
6182 case GGML_TYPE_Q8_0:
6183 case GGML_TYPE_IQ4_NL:
6184 return ctx->device->pipeline_cpy_quant_f32[src->type];
6185 default:
6186 break;
6187 }
6188 }
6189
6190 if (src->type == to) {
6191 // Copy two or four bytes at a time, depending on block size.
6192 // For quantized types, we scale by block size/type size. But
6193 // this path is also used for bf16->bf16 for example, where the
6194 // type size must be exactly 2 or 4.
6195 GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
6196 if ((ggml_type_size(type: src->type) % 4) == 0) {
6197 if (contig) {
6198 return ctx->device->pipeline_contig_cpy_f32_f32;
6199 } else {
6200 return ctx->device->pipeline_cpy_f32_f32;
6201 }
6202 } else {
6203 if (contig) {
6204 return ctx->device->pipeline_contig_cpy_f16_f16;
6205 } else {
6206 return ctx->device->pipeline_cpy_f16_f16;
6207 }
6208 }
6209 }
6210
6211 std::cerr << "Missing CPY op for types: " << ggml_type_name(type: src->type) << " " << ggml_type_name(type: to) << std::endl;
6212 GGML_ABORT("fatal error");
6213}
6214
6215static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
6216 VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
6217 std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
6218 const int tensor_type_size = ggml_type_size(type: tensor->type);
6219
6220 const uint32_t ne = ggml_nelements(tensor);
6221 std::array<uint32_t, 3> elements;
6222
6223 if (ne > 262144) {
6224 elements = { 512, 512, CEIL_DIV(ne, 262144) };
6225 } else if (ne > 512) {
6226 elements = { 512, CEIL_DIV(ne, 512), 1 };
6227 } else {
6228 elements = { ne, 1, 1 };
6229 }
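// Example of the dispatch sizing above (hypothetical ne): ne = 1,000,000 gives
// { 512, 512, CEIL_DIV(1000000, 262144) } = { 512, 512, 4 } (262144 == 512*512);
// ne = 2000 gives { 512, 4, 1 }; ne = 300 gives { 300, 1, 1 }.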
6230
6231 vk_op_unary_push_constants pc = {
6232 .ne: (uint32_t)ne,
6233 .ne00: (uint32_t)tensor->ne[0], .ne01: (uint32_t)tensor->ne[1], .ne02: (uint32_t)tensor->ne[2], .ne03: (uint32_t)tensor->ne[3], .nb00: (uint32_t)tensor->nb[0] / tensor_type_size, .nb01: (uint32_t)tensor->nb[1] / tensor_type_size, .nb02: (uint32_t)tensor->nb[2] / tensor_type_size, .nb03: (uint32_t)tensor->nb[3] / tensor_type_size,
6234 .ne10: (uint32_t)tensor->ne[0], .ne11: (uint32_t)tensor->ne[1], .ne12: (uint32_t)tensor->ne[2], .ne13: (uint32_t)tensor->ne[3], .nb10: 1 , .nb11: (uint32_t)tensor->ne[0] , .nb12: (uint32_t)(tensor->ne[0] * tensor->ne[1]) , .nb13: (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
6235 .misalign_offsets: 0,
6236 .param1: 0.0f, .param2: 0.0f,
6237 .ne0_012mp: 0, .ne0_012L: 0, .ne0_01mp: 0, .ne0_01L: 0, .ne0_0mp: 0, .ne0_0L: 0, .ne1_012mp: 0, .ne1_012L: 0, .ne1_01mp: 0, .ne1_01L: 0, .ne1_0mp: 0, .ne1_0L: 0,
6238 };
6239 init_pushconst_fastdiv(p&: pc);
6240 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, descriptor_buffer_infos: { in, out }, push_constants: pc, elements);
6241 ggml_vk_sync_buffers(ctx, subctx);
6242}
6243
6244static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) {
6245 switch(type) {
6246 case GGML_TYPE_Q8_1:
6247 return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1;
6248 default:
6249 std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
6250 GGML_ABORT("fatal error");
6251 }
6252}
6253
6254static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) {
6255 VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
6256
6257 vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, type: GGML_TYPE_Q8_1, use_x4_blocks: true) : ggml_vk_get_quantize_pipeline(ctx, type: GGML_TYPE_Q8_1, use_x4_blocks: false);
6258
6259 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, descriptor_buffer_infos: { in, out }, push_constants: std::array<uint32_t, 1>{ne}, elements: { ne, 1, 1 });
6260 ggml_vk_sync_buffers(ctx, subctx);
6261}
6262
6263static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
6264 VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
6265 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
6266 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
6267 std::cerr << "))");
6268 GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
6269 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
6270
6271 const uint64_t ne00 = src0->ne[0];
6272 const uint64_t ne01 = src0->ne[1];
6273 const uint64_t ne02 = src0->ne[2];
6274 const uint64_t ne03 = src0->ne[3];
6275
6276 const uint64_t ne10 = src1->ne[0];
6277 const uint64_t ne11 = src1->ne[1];
6278 const uint64_t ne12 = src1->ne[2];
6279 const uint64_t ne13 = src1->ne[3];
6280
6281 const uint64_t ne21 = dst->ne[1];
6282 const uint32_t stride_d = dst->nb[1] / ggml_type_size(type: dst->type);
6283 const uint32_t stride_batch_d = stride_d*ne21;
6284
6285 const uint64_t r2 = ne12 / ne02;
6286 const uint64_t r3 = ne13 / ne03;
6287
6288 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
6289 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
6290 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
6291
6292 vk_buffer d_Qx = nullptr;
6293 size_t qx_buf_offset = 0;
6294 vk_buffer d_Qy = nullptr;
6295 size_t qy_buf_offset = 0;
6296
6297 bool src0_uma = false;
6298 bool src1_uma = false;
6299
6300 if (ctx->device->uma) {
6301 ggml_vk_host_get(device: ctx->device, ptr: src0->data, buf&: d_Qx, buf_offset&: qx_buf_offset);
6302 ggml_vk_host_get(device: ctx->device, ptr: src1->data, buf&: d_Qy, buf_offset&: qy_buf_offset);
6303 src0_uma = d_Qx != nullptr;
6304 src1_uma = d_Qy != nullptr;
6305 }
6306
6307 // Reformat and convert to fp16 if non-contiguous, or, with coopmat2, convert f32 inputs to fp16 for better perf
6308 const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
6309 !ggml_vk_dim01_contiguous(tensor: src0);
6310 const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
6311 (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
6312 !ggml_vk_dim01_contiguous(tensor: src1);
6313
6314 // If src0 is BF16, try to use a BF16 x BF16 multiply
6315 ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
6316
6317 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
6318
6319 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(tensor: src1) && (ne11 * ne10) % 4 == 0;
6320
6321 // Check for mmq first
6322 vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0_type: src0->type, src1_type: GGML_TYPE_Q8_1, prec: (ggml_prec)dst->op_params[0]) : nullptr;
6323
6324 if (mmp == nullptr) {
6325 // Fall back to f16 dequant mul mat
6326 mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0_type: src0->type, src1_type: y_non_contig ? f16_type : src1->type, prec: (ggml_prec)dst->op_params[0]);
6327 quantize_y = false;
6328 }
6329
6330 const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
6331 const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
6332
6333 if (qx_needs_dequant) {
6334 // Fall back to dequant + f16 mulmat
6335 mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0_type: f16_type, src1_type: y_f32_kernel ? GGML_TYPE_F32 : f16_type, prec: (ggml_prec)dst->op_params[0]);
6336 }
6337
6338 // Not implemented
6339 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
6340
6341 const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(width: ne10, align: ggml_vk_guess_matmul_pipeline_align(ctx, mmp, m: ne01, n: ne11, src0_type: qx_needs_dequant ? f16_type : src0->type, src1_type: quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
6342 const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
6343
6344 vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, m: ne01, n: ne11, aligned, src0_type: qx_needs_dequant ? f16_type : src0->type, src1_type: quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
6345
6346 // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
6347 uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
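// Example (hypothetical tile size): if the selected pipeline has wg_denoms[1] == 64, ne11 == 100
// and src1 needs dequantization, padded_n = ROUNDUP_POW2(100, 64) = 128, so the dequantized Y
// buffer reserves 128 columns even though only 100 hold data.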
6348 const int x_ne = ne01 * ne00;
6349 const int y_ne = padded_n * ne10;
6350 const int d_ne = ne11 * ne01;
6351
6352 const uint32_t split_k = ggml_vk_guess_split_k(ctx, m: ne01, n: ne11, k: ne10, disable_split_k, pipeline);
6353
6354 const uint64_t qx_sz = ggml_type_size(type: src0->type) * x_ne / ggml_blck_size(type: src0->type);
6355 const uint64_t qy_sz = ggml_type_size(type: src1->type) * y_ne / ggml_blck_size(type: src1->type);
6356 const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
6357 const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(type: GGML_TYPE_Q8_1) / ggml_blck_size(type: GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
6358 const uint64_t d_sz = sizeof(float) * d_ne;
6359
6360 vk_pipeline to_fp16_vk_0 = nullptr;
6361 vk_pipeline to_fp16_vk_1 = nullptr;
6362 vk_pipeline to_q8_1 = nullptr;
6363
6364 if (x_non_contig) {
6365 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src: src0, dst: nullptr, to: f16_type);
6366 } else {
6367 to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, type: src0->type);
6368 }
6369 if (y_non_contig) {
6370 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src: src1, dst: nullptr, to: f16_type);
6371 } else {
6372 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, type: src1->type);
6373 }
6374 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
6375 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
6376
6377 if (quantize_y) {
6378 to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, type: GGML_TYPE_Q8_1, use_x4_blocks: true);
6379 }
6380
6381 {
6382 const uint64_t x_sz_upd = x_sz * ne02 * ne03;
6383 uint64_t y_sz_upd = y_sz * ne12 * ne13;
6384 if (quantize_y) {
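// Editorial note: the 144-byte granularity is assumed to correspond to one x4 group of Q8_1
// blocks (4 blocks of 36 bytes each: 32 int8 quants plus two fp16 fields), so the quantized Y
// allocation is padded to whole groups.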
6385 y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
6386 }
6387 const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
6388 if (
6389 (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
6390 (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
6391 (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {
6392 GGML_ABORT("Requested preallocation size is too large");
6393 }
6394 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
6395 ctx->prealloc_size_x = x_sz_upd;
6396 ggml_vk_preallocate_buffers(ctx, subctx);
6397 }
6398 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
6399 ctx->prealloc_size_y = y_sz_upd;
6400 ggml_vk_preallocate_buffers(ctx, subctx);
6401 }
6402 if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
6403 ctx->prealloc_size_split_k = split_k_size;
6404 ggml_vk_preallocate_buffers(ctx, subctx);
6405 }
6406
6407 // Request descriptor sets
6408 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
6409 if (qx_needs_dequant) {
6410 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: to_fp16_vk_0, n: 1);
6411 }
6412 if (qy_needs_dequant) {
6413 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: to_fp16_vk_1, n: 1);
6414 }
6415 if (quantize_y) {
6416 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: to_q8_1, n: 1);
6417 }
6418 if (split_k > 1) {
6419 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: ctx->device->pipeline_matmul_split_k_reduce, n: 1);
6420 }
6421 }
6422
6423 vk_buffer d_D = dst_buf_ctx->dev_buffer;
6424 const uint64_t d_buf_offset = vk_tensor_offset(tensor: dst) + dst->view_offs;
6425 GGML_ASSERT(d_D != nullptr);
6426 GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03);
6427 vk_buffer d_X;
6428 uint64_t x_buf_offset = 0;
6429 vk_buffer d_Y;
6430 uint64_t y_buf_offset = 0;
6431 if (!src0_uma) {
6432 d_Qx = src0_buf_ctx->dev_buffer;
6433 qx_buf_offset = vk_tensor_offset(tensor: src0) + src0->view_offs;
6434 GGML_ASSERT(d_Qx != nullptr);
6435 }
6436 if (!src1_uma) {
6437 d_Qy = src1_buf_ctx->dev_buffer;
6438 qy_buf_offset = vk_tensor_offset(tensor: src1) + src1->view_offs;
6439 GGML_ASSERT(d_Qy != nullptr);
6440 }
6441 if (qx_needs_dequant) {
6442 d_X = ctx->prealloc_x;
6443 GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
6444 } else {
6445 d_X = d_Qx;
6446 x_buf_offset = qx_buf_offset;
6447 GGML_ASSERT(qx_sz == x_sz);
6448 }
6449 if (qy_needs_dequant) {
6450 d_Y = ctx->prealloc_y;
6451 GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
6452 } else if (quantize_y) {
6453 d_Y = ctx->prealloc_y;
6454 GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
6455 } else {
6456 d_Y = d_Qy;
6457 y_buf_offset = qy_buf_offset;
6458 GGML_ASSERT(qy_sz == y_sz);
6459 }
6460
6461 if (x_non_contig || qx_needs_dequant) {
6462 if (ctx->prealloc_x_need_sync) {
6463 ggml_vk_sync_buffers(ctx, subctx);
6464 }
6465 }
6466
6467 if (x_non_contig) {
6468 ggml_vk_cpy_to_contiguous(ctx, subctx, pipeline: to_fp16_vk_0, tensor: src0, in: ggml_vk_subbuffer(ctx, buf: d_Qx, offset: qx_buf_offset), out: ggml_vk_subbuffer(ctx, buf: d_X, offset: 0));
6469 } else if (qx_needs_dequant) {
6470 const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(tensor: src0)) };
6471 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline&: to_fp16_vk_0, descriptor_buffer_infos: { vk_subbuffer{ .buffer: d_Qx, .offset: qx_buf_offset, .size: qx_sz * ne02 * ne03 }, vk_subbuffer{ .buffer: d_X, .offset: 0, .size: x_sz * ne02 * ne03 } }, push_constants: pc, elements: { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
6472 ggml_vk_sync_buffers(ctx, subctx);
6473 }
6474 if (y_non_contig) {
6475 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
6476 ctx->prealloc_y_last_tensor_used != src1) {
6477 if (ctx->prealloc_y_need_sync) {
6478 ggml_vk_sync_buffers(ctx, subctx);
6479 }
6480 ggml_vk_cpy_to_contiguous(ctx, subctx, pipeline: to_fp16_vk_1, tensor: src1, in: ggml_vk_subbuffer(ctx, buf: d_Qy, offset: qy_buf_offset), out: ggml_vk_subbuffer(ctx, buf: d_Y, offset: 0));
6481 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
6482 ctx->prealloc_y_last_tensor_used = src1;
6483 }
6484 }
6485 if (quantize_y) {
6486 if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
6487 ctx->prealloc_y_last_tensor_used != src1) {
6488 if (ctx->prealloc_y_need_sync) {
6489 ggml_vk_sync_buffers(ctx, subctx);
6490 }
6491 ggml_vk_quantize_q8_1(ctx, subctx, in: ggml_vk_subbuffer(ctx, buf: d_Qy, offset: qy_buf_offset), out: ggml_vk_subbuffer(ctx, buf: d_Y, offset: 0), ne: y_ne * ne12 * ne13, use_x4_blocks: true);
6492 ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
6493 ctx->prealloc_y_last_tensor_used = src1;
6494 }
6495 }
6496
6497 uint32_t stride_batch_x = ne00*ne01;
6498 uint32_t stride_batch_y = ne10*ne11;
6499
6500 if (!ggml_vk_dim01_contiguous(tensor: src0) && !qx_needs_dequant) {
6501 stride_batch_x = src0->nb[0] / ggml_type_size(type: src0->type);
6502 }
6503
6504 if (!ggml_vk_dim01_contiguous(tensor: src1) && !qy_needs_dequant && !quantize_y) {
6505 stride_batch_y = src1->nb[0] / ggml_type_size(type: src1->type);
6506 }
6507
6508 uint32_t y_sz_total = y_sz * ne12 * ne13;
6509 if (quantize_y) {
6510 y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
6511 }
6512
6513 // compute
6514 ggml_vk_matmul(
6515 ctx, subctx, pipeline,
6516 a: { .buffer: d_X, .offset: x_buf_offset, .size: x_sz * ne02 * ne03 }, b: { .buffer: d_Y, .offset: y_buf_offset, .size: y_sz_total },
6517 d: ggml_vk_subbuffer(ctx, buf: d_D, offset: d_buf_offset), split_k_buffer: { .buffer: ctx->prealloc_split_k, .offset: 0, .size: d_sz * ne12 * ne13 * split_k },
6518 m: ne01, n: ne11, k: ne10,
6519 stride_a: ne10, stride_b: ne10, stride_d, batch_stride_a: stride_batch_x, batch_stride_b: stride_batch_y, batch_stride_d: stride_batch_d,
6520 split_k, batch: ne12*ne13, ne02, ne12, broadcast2: r2, broadcast3: r3, padded_n
6521 ); // NOLINT
6522
6523 if (x_non_contig || qx_needs_dequant) {
6524 ctx->prealloc_x_need_sync = true;
6525 }
6526 if (y_non_contig || quantize_y) {
6527 ctx->prealloc_y_need_sync = true;
6528 }
6529}
6530
6531// Device tuning
6532static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) {
6533 if (device->mmvq_mode == 1) {
6534 return true;
6535 } else if (device->mmvq_mode == -1) {
6536 return false;
6537 }
6538
6539 // MMVQ is generally good for batches
6540 if (n > 1) {
6541 return true;
6542 }
6543
6544 switch (device->vendor_id) {
6545 case VK_VENDOR_ID_NVIDIA:
6546 switch (src0_type) {
6547 case GGML_TYPE_Q8_0:
6548 return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
6549 default:
6550 return true;
6551 }
6552 case VK_VENDOR_ID_AMD:
6553 switch (src0_type) {
6554 case GGML_TYPE_Q8_0:
6555 return device->architecture == vk_device_architecture::AMD_GCN;
6556 default:
6557 return true;
6558 }
6559 case VK_VENDOR_ID_INTEL:
6560 switch (src0_type) {
6561 // From tests on A770 Linux, may need more tuning
6562 case GGML_TYPE_Q4_0:
6563 case GGML_TYPE_Q5_1:
6564 return false;
6565 default:
6566 return true;
6567 }
6568 default:
6569 return true;
6570 }
6571
6572 GGML_UNUSED(m);
6573 GGML_UNUSED(k);
6574}
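// Example of the tuning above (n == 1, i.e. single-token decode): Q8_0 uses MMVQ only on NVIDIA
// pre-Turing and AMD GCN; Intel skips MMVQ for Q4_0 and Q5_1; everything else, and any n > 1,
// uses MMVQ.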
6575
6576static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
6577 ggml_tensor * dst = cgraph->nodes[node_idx];
6578 const ggml_tensor * src0 = dst->src[0];
6579 const ggml_tensor * src1 = dst->src[1];
6580
6581 VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
6582 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
6583 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
6584 std::cerr << "))");
6585 GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
6586 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
6587
6588 const uint64_t ne00 = src0->ne[0];
6589 const uint64_t ne01 = src0->ne[1];
6590 const uint64_t ne02 = src0->ne[2];
6591 const uint64_t ne03 = src0->ne[3];
6592
6593 const uint64_t ne10 = src1->ne[0];
6594 const uint64_t ne11 = src1->ne[1];
6595 const uint64_t ne12 = src1->ne[2];
6596 const uint64_t ne13 = src1->ne[3];
6597
6598 const uint64_t ne20 = dst->ne[0];
6599 const uint64_t ne21 = dst->ne[1];
6600 const uint64_t ne22 = dst->ne[2];
6601 const uint64_t ne23 = dst->ne[3];
6602
6603 const uint64_t r2 = ne12 / ne02;
6604 const uint64_t r3 = ne13 / ne03;
6605
6606 // batch_n indicates that we need to compute a few vector results, and this assumes
6607 // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
6608 GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
6609 bool batch_n = ne11 > 1;
6610
6611 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
6612 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
6613
6614 vk_buffer d_Qx = nullptr;
6615 size_t qx_buf_offset = 0;
6616 vk_buffer d_Qy = nullptr;
6617 size_t qy_buf_offset = 0;
6618
6619 bool src0_uma = false;
6620 bool src1_uma = false;
6621
6622 if (ctx->device->uma) {
6623 ggml_vk_host_get(device: ctx->device, ptr: src0->data, buf&: d_Qx, buf_offset&: qx_buf_offset);
6624 ggml_vk_host_get(device: ctx->device, ptr: src1->data, buf&: d_Qy, buf_offset&: qy_buf_offset);
6625 src0_uma = d_Qx != nullptr;
6626 src1_uma = d_Qy != nullptr;
6627 }
6628
6629 const bool x_non_contig = !ggml_vk_dim01_contiguous(tensor: src0);
6630 const bool y_non_contig = !ggml_vk_dim01_contiguous(tensor: src1);
6631
6632 const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
6633 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(tensor: src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(device: ctx->device, m: ne01, n: ne11, k: ne10, src0_type: src0->type);
6634
6635 vk_pipeline to_fp16_vk_0 = nullptr;
6636 vk_pipeline to_fp16_vk_1 = nullptr;
6637 if (x_non_contig) {
6638 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src: src0, dst: nullptr, to: src0->type);
6639 }
6640 if (y_non_contig) {
6641 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src: src1, dst: nullptr, to: src1->type);
6642 } else {
6643 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, type: src1->type);
6644 }
6645
6646 // Check for mmq first
6647 vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, a_type: src0->type, b_type: GGML_TYPE_Q8_1, num_cols: ne11, m: ne20, k: ne00) : nullptr;
6648 vk_pipeline to_q8_1 = nullptr;
6649
6650 if (dmmv == nullptr) {
6651 // Fall back to f16 dequant mul mat
6652 dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, a_type: src0->type, b_type: src1->type, num_cols: ne11, m: ne20, k: ne00);
6653 quantize_y = false;
6654 }
6655
6656 if (quantize_y) {
6657 to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, type: GGML_TYPE_Q8_1, use_x4_blocks: true);
6658 }
6659
6660 const bool qx_needs_dequant = x_non_contig;
6661 const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
6662
6663 // Not implemented
6664 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
6665
6666 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
6667 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
6668 GGML_ASSERT(dmmv != nullptr);
6669
6670 const uint64_t x_ne = ne01 * ne00;
6671 const uint64_t y_ne = ne11 * ne10;
6672 const uint64_t d_ne = ne11 * ne01;
6673
6674 const uint64_t qx_sz = ggml_vk_align_size(width: ggml_type_size(type: src0->type) * x_ne / ggml_blck_size(type: src0->type), align: ctx->device->properties.limits.minStorageBufferOffsetAlignment);
6675 const uint64_t qy_sz = ggml_type_size(type: src1->type) * y_ne / ggml_blck_size(type: src1->type);
6676 const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(width: ggml_type_size(type: src0->type) * x_ne, align: ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
6677 const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(type: GGML_TYPE_Q8_1) / ggml_blck_size(type: GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
6678 const uint64_t d_sz = sizeof(float) * d_ne;
6679
6680 {
6681 const uint64_t x_sz_upd = x_sz * ne02 * ne03;
6682 uint64_t y_sz_upd = y_sz * ne12 * ne13;
6683 if (quantize_y) {
6684 y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
6685 }
6686 if (
6687 (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
6688 (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
6689 GGML_ABORT("Requested preallocation size is too large");
6690 }
6691 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
6692 ctx->prealloc_size_x = x_sz_upd;
6693 ggml_vk_preallocate_buffers(ctx, subctx);
6694 }
6695 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
6696 ctx->prealloc_size_y = y_sz_upd;
6697 ggml_vk_preallocate_buffers(ctx, subctx);
6698 }
6699
6700 // Request descriptor sets
6701 if (qx_needs_dequant) {
6702 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: to_fp16_vk_0, n: 1);
6703 }
6704 if (qy_needs_dequant) {
6705 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: to_fp16_vk_1, n: 1);
6706 }
6707 if (quantize_y) {
6708 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: to_q8_1, n: 1);
6709 }
6710 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: dmmv, n: 1);
6711 }
6712
6713 vk_buffer d_D;
6714 uint64_t d_buf_offset = 0;
6715
6716 if (ctx->num_additional_fused_ops > 0) {
6717 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
6718 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context;
6719 d_D = dst_buf_ctx->dev_buffer;
6720 d_buf_offset = vk_tensor_offset(tensor: add) + add->view_offs;
6721 } else {
6722 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
6723 d_D = dst_buf_ctx->dev_buffer;
6724 d_buf_offset = vk_tensor_offset(tensor: dst) + dst->view_offs;
6725 }
6726
6727 GGML_ASSERT(d_D != nullptr);
6728 vk_buffer d_X;
6729 uint64_t x_buf_offset = 0;
6730 vk_buffer d_Y;
6731 uint64_t y_buf_offset = 0;
6732 if(!src0_uma) {
6733 d_Qx = src0_buf_ctx->dev_buffer;
6734 qx_buf_offset = vk_tensor_offset(tensor: src0) + src0->view_offs;
6735 GGML_ASSERT(d_Qx != nullptr);
6736 }
6737 if(!src1_uma) {
6738 d_Qy = src1_buf_ctx->dev_buffer;
6739 qy_buf_offset = vk_tensor_offset(tensor: src1) + src1->view_offs;
6740 GGML_ASSERT(d_Qy != nullptr);
6741 }
6742 if (qx_needs_dequant) {
6743 d_X = ctx->prealloc_x;
6744 } else {
6745 d_X = d_Qx;
6746 x_buf_offset = qx_buf_offset;
6747 GGML_ASSERT(qx_sz == x_sz);
6748 }
6749 if (qy_needs_dequant) {
6750 d_Y = ctx->prealloc_y;
6751 } else if (quantize_y) {
6752 d_Y = ctx->prealloc_y;
6753 GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
6754 } else {
6755 d_Y = d_Qy;
6756 y_buf_offset = qy_buf_offset;
6757 GGML_ASSERT(qy_sz == y_sz);
6758 }
6759
6760 if (x_non_contig) {
6761 if (ctx->prealloc_x_need_sync) {
6762 ggml_vk_sync_buffers(ctx, subctx);
6763 }
6764
6765 GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
6766 ggml_vk_cpy_to_contiguous(ctx, subctx, pipeline: to_fp16_vk_0, tensor: src0, in: ggml_vk_subbuffer(ctx, buf: d_Qx, offset: qx_buf_offset), out: ggml_vk_subbuffer(ctx, buf: d_X, offset: 0));
6767 }
6768 if (y_non_contig) {
6769 GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
6770 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
6771 ctx->prealloc_y_last_tensor_used != src1) {
6772 if (ctx->prealloc_y_need_sync) {
6773 ggml_vk_sync_buffers(ctx, subctx);
6774 }
6775 ggml_vk_cpy_to_contiguous(ctx, subctx, pipeline: to_fp16_vk_1, tensor: src1, in: ggml_vk_subbuffer(ctx, buf: d_Qy, offset: qy_buf_offset), out: ggml_vk_subbuffer(ctx, buf: d_Y, offset: 0));
6776 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
6777 ctx->prealloc_y_last_tensor_used = src1;
6778 }
6779 }
6780 if (quantize_y) {
6781 if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
6782 ctx->prealloc_y_last_tensor_used != src1) {
6783 if (ctx->prealloc_y_need_sync) {
6784 ggml_vk_sync_buffers(ctx, subctx);
6785 }
6786 ggml_vk_quantize_q8_1(ctx, subctx, in: ggml_vk_subbuffer(ctx, buf: d_Qy, offset: qy_buf_offset), out: ggml_vk_subbuffer(ctx, buf: d_Y, offset: 0), ne: y_ne * ne12 * ne13, use_x4_blocks: true);
6787 ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
6788 ctx->prealloc_y_last_tensor_used = src1;
6789 }
6790 }
6791
6792 // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
6793 uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
6794 uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
6795 uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
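// For example (hypothetical shapes): ne11 = 4 with ne12 = ne13 = 1 sets batch_n, so
// stride_batch_x = 0 (the same A matrix is reused), stride_batch_y = ne10 (one row of B per
// result) and stride_batch_d = ne20 (one row of D per result).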
6796
6797 if (!ggml_vk_dim01_contiguous(tensor: src0) && !qx_needs_dequant) {
6798 stride_batch_x = src0->nb[0] / ggml_type_size(type: src0->type);
6799 }
6800
6801 if (!ggml_vk_dim01_contiguous(tensor: src1) && !qy_needs_dequant) {
6802 stride_batch_y = src1->nb[0] / ggml_type_size(type: src1->type);
6803 }
6804
6805 const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
6806
6807 uint32_t groups_x = ne01;
6808 uint32_t groups_z = 1;
6809
6810 if (ne01 > max_groups_x) {
6811 groups_z = 64;
6812 groups_x = CEIL_DIV(groups_x, groups_z);
6813 }
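// Example: with ne01 = 200000 and maxComputeWorkGroupCount[0] = 65535, groups_z = 64 and
// groups_x = CEIL_DIV(200000, 64) = 3125, which fits within the limit.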
6814
6815 // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time
6816 uint32_t y_sz_total = y_sz * ne12 * ne13;
6817 if (quantize_y) {
6818 y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
6819 }
6820
6821 uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
6822
6823 vk_buffer d_B = d_D;
6824 size_t b_buf_offset = 0;
6825 uint64_t b_sz = 0;
6826
6827 if (enable_bias) {
6828 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
6829 const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
6830
6831 bool b_uma = false;
6832 if (ctx->device->uma) {
6833 ggml_vk_host_get(device: ctx->device, ptr: bias->data, buf&: d_B, buf_offset&: b_buf_offset);
6834 b_uma = d_B != nullptr;
6835 }
6836 if(!b_uma) {
6837 ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context;
6838 d_B = bias_buf_ctx->dev_buffer;
6839 b_buf_offset = vk_tensor_offset(tensor: bias) + bias->view_offs;
6840 GGML_ASSERT(d_B != nullptr);
6841 b_sz = ggml_nbytes(tensor: bias);
6842 }
6843 }
6844
6845 // compute
6846 const vk_mat_vec_push_constants pc = {
6847 .ncols: (uint32_t)ne00, .stride_a: (uint32_t)ne10, .stride_b: (uint32_t)ne10, .stride_d: (uint32_t)ne01,
6848 .batch_stride_a: stride_batch_x, .batch_stride_b: stride_batch_y, .batch_stride_d: stride_batch_d, .enable_bias: enable_bias,
6849 .ne02: (uint32_t)ne02, .ne12: (uint32_t)ne12, .broadcast2: (uint32_t)r2, .broadcast3: (uint32_t)r3,
6850 };
6851 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline&: dmmv,
6852 descriptor_buffer_infos: {
6853 vk_subbuffer{ .buffer: d_X, .offset: x_buf_offset, .size: x_sz * ne02 * ne03 },
6854 vk_subbuffer{ .buffer: d_Y, .offset: y_buf_offset, .size: y_sz_total },
6855 vk_subbuffer{ .buffer: d_D, .offset: d_buf_offset, .size: d_sz * ne22 * ne23},
6856 vk_subbuffer{ .buffer: d_B, .offset: b_buf_offset, .size: b_sz },
6857 },
6858 push_constants: pc, elements: { groups_x, (uint32_t)(ne12 * ne13), groups_z });
6859
6860 if (x_non_contig) {
6861 ctx->prealloc_x_need_sync = true;
6862 }
6863 if (y_non_contig || quantize_y) {
6864 ctx->prealloc_y_need_sync = true;
6865 }
6866}
6867
6868static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
6869 ggml_tensor * dst = cgraph->nodes[node_idx];
6870 const ggml_tensor * src0 = dst->src[0];
6871 const ggml_tensor * src1 = dst->src[1];
6872 VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
6873 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
6874 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
6875 std::cerr << "))");
6876 GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
6877 GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT
6878 GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT
6879 GGML_ASSERT(src0->type == GGML_TYPE_F16);
6880 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6881
6882 const uint64_t ne00 = src0->ne[0];
6883 const uint64_t ne01 = src0->ne[1];
6884 const uint64_t ne02 = src0->ne[2];
6885 // const uint64_t ne03 = src0->ne[3];
6886
6887 const uint64_t ne10 = src1->ne[0];
6888 const uint64_t ne11 = src1->ne[1];
6889 const uint64_t ne12 = src1->ne[2];
6890 // const uint64_t ne13 = src1->ne[3];
6891
6892 GGML_ASSERT(ne11 == 1);
6893
6894 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
6895 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
6896
6897 vk_buffer d_Qy = nullptr;
6898 size_t qy_buf_offset = 0;
6899
6900 bool src1_uma = false;
6901
6902 if (ctx->device->uma) {
6903 ggml_vk_host_get(device: ctx->device, ptr: src1->data, buf&: d_Qy, buf_offset&: qy_buf_offset);
6904 src1_uma = d_Qy != nullptr;
6905 }
6906
6907 const uint64_t x_ne = ne00 * ne01 * ne02;
6908 const uint64_t y_ne = ne10 * ne11 * ne12;
6909 const uint64_t d_ne = ne01 * ne11 * ne12;
6910
6911 const uint64_t qx_sz = ggml_vk_align_size(width: ggml_type_size(type: src0->type) * x_ne / ggml_blck_size(type: src0->type), align: ctx->device->properties.limits.minStorageBufferOffsetAlignment);
6912 const uint64_t qy_sz = ggml_type_size(type: src1->type) * y_ne / ggml_blck_size(type: src1->type);
6913 const uint64_t d_sz = sizeof(float) * d_ne;
6914
6915 // With grouped query attention there is more than one Q matrix per K/V matrix.
6916 uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
6917 if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
6918 gqa_ratio = 1;
6919 }
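// Example: ne02 = 8 KV heads and ne12 = 32 Q heads gives gqa_ratio = 4; ne12 = 80 would give 10
// (> 8) and ne12 = 20 would not divide evenly (8 * 2 != 20), so both fall back to gqa_ratio = 1.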
6920
6921 {
6922 // Request descriptor sets
6923 ggml_pipeline_request_descriptor_sets(ctx, pipeline&: ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], n: 1);
6924 }
6925
6926 vk_buffer d_D;
6927 uint64_t d_buf_offset = 0;
6928
6929 if (ctx->num_additional_fused_ops > 0) {
6930 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
6931 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context;
6932 d_D = dst_buf_ctx->dev_buffer;
6933 d_buf_offset = vk_tensor_offset(tensor: add) + add->view_offs;
6934 } else {
6935 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
6936 d_D = dst_buf_ctx->dev_buffer;
6937 d_buf_offset = vk_tensor_offset(tensor: dst) + dst->view_offs;
6938 }
6939 GGML_ASSERT(d_D != nullptr);
6940 vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
6941 const uint64_t qx_buf_offset = vk_tensor_offset(tensor: src0) + src0->view_offs;
6942 GGML_ASSERT(d_Qx != nullptr);
6943 if (!src1_uma) {
6944 d_Qy = src1_buf_ctx->dev_buffer;
6945 qy_buf_offset = vk_tensor_offset(tensor: src1) + src1->view_offs;
6946 GGML_ASSERT(d_Qy != nullptr);
6947 }
6948
6949 const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
6950 const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
6951
6952 const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
6953 const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
6954
6955 uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
6956
6957 vk_buffer d_B = d_D;
6958 size_t b_buf_offset = 0;
6959 uint64_t b_sz = 0;
6960
6961 if (enable_bias) {
6962 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
6963 const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
6964
6965 bool b_uma = false;
6966 if (ctx->device->uma) {
6967 ggml_vk_host_get(device: ctx->device, ptr: bias->data, buf&: d_B, buf_offset&: b_buf_offset);
6968 b_uma = d_B != nullptr;
6969 }
6970 if(!b_uma) {
6971 ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context;
6972 d_B = bias_buf_ctx->dev_buffer;
6973 b_buf_offset = vk_tensor_offset(tensor: bias) + bias->view_offs;
6974 GGML_ASSERT(d_B != nullptr);
6975 b_sz = ggml_nbytes(tensor: bias);
6976 }
6977 }
6978
6979 // compute
6980 const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(type: src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(type: dst->type)), enable_bias };
6981
6982 uint32_t workgroups_z = (uint32_t)ne12;
6983 // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
6984 if (gqa_ratio > 1) {
6985 workgroups_z /= gqa_ratio;
6986 }
6987
6988 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline&: ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1],
6989 descriptor_buffer_infos: {
6990 vk_subbuffer{ .buffer: d_Qx, .offset: qx_buf_offset, .size: qx_sz },
6991 vk_subbuffer{ .buffer: d_Qy, .offset: qy_buffer_offset, .size: qy_sz + qy_shader_offset },
6992 vk_subbuffer{ .buffer: d_D, .offset: d_buffer_offset, .size: d_sz + d_shader_offset },
6993 vk_subbuffer{ .buffer: d_B, .offset: b_buf_offset, .size: b_sz },
6994 }, push_constants: pc, elements: { 1, (uint32_t)ne01, workgroups_z });
6995}
6996
6997static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
6998 ggml_tensor * dst = cgraph->nodes[node_idx];
6999 const ggml_tensor * src0 = dst->src[0];
7000 const ggml_tensor * src1 = dst->src[1];
7001 VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7002 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7003 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
7004 std::cerr << "))");
7005 GGML_ASSERT(!ggml_is_transposed(src0));
7006 GGML_ASSERT(!ggml_is_transposed(src1));
7007 GGML_ASSERT(!ggml_is_permuted(src0));
7008 GGML_ASSERT(src0->type == GGML_TYPE_F16);
7009 GGML_ASSERT(src1->type == GGML_TYPE_F32);
7010
7011 const uint64_t ne00 = src0->ne[0];
7012 const uint64_t ne01 = src0->ne[1];
7013 const uint64_t ne02 = src0->ne[2];
7014 const uint64_t ne03 = src0->ne[3];
7015
7016 const uint64_t nb01 = src0->nb[1];
7017 const uint64_t nb02 = src0->nb[2];
7018
7019 const uint64_t nb12 = src1->nb[2];
7020
7021 // const uint64_t ne10 = src1->ne[0];
7022 const uint64_t ne11 = src1->ne[1];
7023 const uint64_t ne12 = src1->ne[2];
7024 // const uint64_t ne13 = src1->ne[3];
7025
7026 const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
7027 const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
7028 const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
7029
7030 GGML_ASSERT(ne11 == 1);
7031 GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
7032
7033 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
7034 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
7035
7036 vk_buffer d_Qy = nullptr;
7037 size_t qy_buf_offset = 0;
7038
7039 bool src1_uma = false;
7040
7041 if (ctx->device->uma) {
7042 ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
7043 src1_uma = d_Qy != nullptr;
7044 }
7045
7046 const uint64_t d_ne = ne01 * ne11 * ne12 * ne03;
7047
7048 const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
7049 const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
7050 const uint32_t channel_stride_y = nb12 / sizeof(float);
7051
7052 const uint64_t qx_sz = ggml_nbytes(src0);
7053 const uint64_t qy_sz = ggml_nbytes(src1);
7054 const uint64_t d_sz = sizeof(float) * d_ne;
7055
7056 {
7057 // Request descriptor sets
7058 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
7059 }
7060
7061 vk_buffer d_D;
7062 uint64_t d_buf_offset = 0;
7063
7064 if (ctx->num_additional_fused_ops > 0) {
7065 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
7066 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context;
7067 d_D = dst_buf_ctx->dev_buffer;
7068 d_buf_offset = vk_tensor_offset(add) + add->view_offs;
7069 } else {
7070 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
7071 d_D = dst_buf_ctx->dev_buffer;
7072 d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
7073 }
7074
7075 GGML_ASSERT(d_D != nullptr);
7076 vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
7077 const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
7078 GGML_ASSERT(d_Qx != nullptr);
7079 if (!src1_uma) {
7080 d_Qy = src1_buf_ctx->dev_buffer;
7081 qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
7082 GGML_ASSERT(d_Qy != nullptr);
7083 }
7084
7085 const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
7086 const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
7087
7088 const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
7089 const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
7090
7091 uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
7092
7093 vk_buffer d_B = d_D;
7094 size_t b_buf_offset = 0;
7095 uint64_t b_sz = 0;
7096
7097 if (enable_bias) {
7098 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
7099 const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
7100
7101 bool b_uma = false;
7102 if (ctx->device->uma) {
7103 ggml_vk_host_get(ctx->device, bias->data, d_B, b_buf_offset);
7104 b_uma = d_B != nullptr;
7105 }
7106 if(!b_uma) {
7107 ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context;
7108 d_B = bias_buf_ctx->dev_buffer;
7109 b_buf_offset = vk_tensor_offset(bias) + bias->view_offs;
7110 GGML_ASSERT(d_B != nullptr);
7111 b_sz = ggml_nbytes(bias);
7112 }
7113 }
7114
7115 // compute
7116 const std::array<uint32_t, 13> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23, enable_bias };
7117 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
7118 {
7119 vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz },
7120 vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset },
7121 vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset },
7122 vk_subbuffer{ d_B, b_buf_offset, b_sz },
7123 }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
7124}
7125
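// Top-level mul_mat dispatcher: splits the M dimension when src0 exceeds maxStorageBufferRange,
// otherwise routes to the specialized p021/nc matrix-vector shaders, the quantized
// matrix-vector path, or the general ggml_vk_mul_mat_q_f16 matmul.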
7126static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
7127 ggml_tensor * dst = cgraph->nodes[node_idx];
7128 ggml_tensor * src0 = dst->src[0];
7129 ggml_tensor * src1 = dst->src[1];
7130 VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
7131
7132 // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
7133 // where the M dimension is very large.
7134 // Split_k doesn't work with M splitting.
7135 const size_t nbytes = ggml_nbytes(src0);
7136 const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
7137 if (needs_split) {
7138 // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
7139 const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
7140 uint32_t m_offset = 0;
7141 while (m_offset < dst->ne[0]) {
7142 const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
7143 ggml_tensor dst2 = *dst;
7144 ggml_tensor src02 = *src0;
7145
7146 dst2.view_src = dst->view_src ? dst->view_src : dst;
7147 src02.view_src = src0->view_src ? src0->view_src : src0;
7148
7149 dst2.view_offs += m_offset * dst->nb[0];
7150 src02.view_offs += m_offset * src0->nb[1];
7151 dst2.ne[0] = cur_M_size;
7152 src02.ne[1] = cur_M_size;
7153
7154 ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true);
7155
7156 m_offset += cur_M_size;
7157 }
7158 } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
7159 // detect 0213 permutation, and batch size of 1
7160 src0->nb[0] <= src0->nb[2] &&
7161 src0->nb[2] <= src0->nb[1] &&
7162 src0->nb[1] <= src0->nb[3] &&
7163 src1->nb[0] <= src1->nb[2] &&
7164 src1->nb[2] <= src1->nb[1] &&
7165 src1->nb[1] <= src1->nb[3] &&
7166 src0->ne[3] == 1 &&
7167 src1->ne[3] == 1) {
7168 ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
7169 } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
7170 !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
7171 ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
7172 // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
7173 // when ne12 and ne13 are one.
7174 } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
7175 (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
7176 ggml_vk_mul_mat_vec_q_f16(ctx, subctx, cgraph, node_idx);
7177 } else {
7178 ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false);
7179 }
7180}
7181
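// Indirect (mixture-of-experts) matrix multiply: the ids tensor selects which expert matrix
// in src0 each row of src1 is multiplied against. Uses the same dequant / quantize-to-Q8_1
// staging logic as the regular matmul path.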
7182static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
7183 VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7184 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7185 std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
7186 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
7187 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
7188 GGML_ASSERT(ids->type == GGML_TYPE_I32);
7189
7190 const uint64_t ne00 = src0->ne[0];
7191 const uint64_t ne01 = src0->ne[1];
7192 const uint64_t ne02 = src0->ne[2];
7193 const uint64_t ne03 = src0->ne[3];
7194
7195 const uint64_t ne10 = src1->ne[0];
7196 const uint64_t ne11 = src1->ne[1];
7197 const uint64_t ne12 = src1->ne[2];
7198 const uint64_t ne13 = src1->ne[3];
7199
7200 const uint64_t nei0 = ids->ne[0];
7201 const uint64_t nei1 = ids->ne[1];
7202
7203 const uint32_t nbi1 = ids->nb[1];
7204 const uint32_t nbi2 = ids->nb[2];
7205
7206 const uint64_t ne20 = dst->ne[0];
7207 const uint64_t ne21 = dst->ne[1];
7208 const uint64_t ne22 = dst->ne[2];
7209 const uint64_t ne23 = dst->ne[3];
7210
7211 const uint64_t n_as = ne02;
7212
7213 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
7214 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
7215 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
7216 ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
7217
7218 vk_buffer d_Qx = nullptr;
7219 size_t qx_buf_offset = 0;
7220 vk_buffer d_Qy = nullptr;
7221 size_t qy_buf_offset = 0;
7222 vk_buffer d_ids = nullptr;
7223 size_t ids_buf_offset = 0;
7224
7225 bool src0_uma = false;
7226 bool src1_uma = false;
7227 bool ids_uma = false;
7228
7229 if (ctx->device->uma) {
7230 ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
7231 ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
7232 ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
7233 src0_uma = d_Qx != nullptr;
7234 src1_uma = d_Qy != nullptr;
7235 ids_uma = d_ids != nullptr;
7236 }
7237
7238 // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
7239 const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
7240 !ggml_vk_dim01_contiguous(src0);
7241 const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
7242 (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
7243 !ggml_vk_dim01_contiguous(src1);
7244
7245 // If src0 is BF16, try to use a BF16 x BF16 multiply
7246 ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
7247
7248 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
7249
7250 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
7251
7252 // Check for mmq first
7253 vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
7254
7255 if (mmp == nullptr) {
7256 // Fall back to f16 dequant mul mat
7257 mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
7258 quantize_y = false;
7259 }
7260
7261 const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
7262 const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
7263
7264 if (qx_needs_dequant) {
7265 // Fall back to dequant + f16 mulmat
7266 mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
7267 }
7268
7269 // Not implemented
7270 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
7271
7272 const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
7273 const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
7274
7275 vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
7276
7277 // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
7278 uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
7279 const uint64_t x_ne = ne01 * ne00;
7280 const uint64_t y_ne = padded_n * ne10;
7281 const uint64_t d_ne = ne21 * ne20;
7282
7283 const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
7284 const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
7285 const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
7286 const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
7287 const uint64_t ids_sz = nbi2;
7288 const uint64_t d_sz = sizeof(float) * d_ne;
7289
7290 vk_pipeline to_fp16_vk_0 = nullptr;
7291 vk_pipeline to_fp16_vk_1 = nullptr;
7292 vk_pipeline to_q8_1 = nullptr;
7293
7294 if (x_non_contig) {
7295 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
7296 } else {
7297 to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
7298 }
7299 if (y_non_contig) {
7300 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
7301 } else {
7302 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
7303 }
7304 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
7305 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
7306
7307 if (quantize_y) {
7308 to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
7309 }
7310
7311 {
7312 const uint64_t x_sz_upd = x_sz * ne02 * ne03;
7313 uint64_t y_sz_upd = y_sz * ne12 * ne13;
7314 if (quantize_y) {
7315 y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
7316 }
7317 if (
7318 (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
7319 (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
7320 GGML_ABORT("Requested preallocation size is too large");
7321 }
7322 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
7323 ctx->prealloc_size_x = x_sz_upd;
7324 ggml_vk_preallocate_buffers(ctx, subctx);
7325 }
7326 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
7327 ctx->prealloc_size_y = y_sz_upd;
7328 ggml_vk_preallocate_buffers(ctx, subctx);
7329 }
7330
7331 // Request descriptor sets
7332 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
7333 if (qx_needs_dequant) {
7334 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
7335 }
7336 if (qy_needs_dequant) {
7337 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
7338 }
7339 if (quantize_y) {
7340 ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
7341 }
7342 }
7343
7344 vk_buffer d_D = dst_buf_ctx->dev_buffer;
7345 const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
7346 GGML_ASSERT(d_D != nullptr);
7347 vk_buffer d_X;
7348 uint64_t x_buf_offset = 0;
7349 vk_buffer d_Y;
7350 uint64_t y_buf_offset = 0;
7351 if (!src0_uma) {
7352 d_Qx = src0_buf_ctx->dev_buffer;
7353 qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
7354 GGML_ASSERT(d_Qx != nullptr);
7355 }
7356 if (!src1_uma) {
7357 d_Qy = src1_buf_ctx->dev_buffer;
7358 qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
7359 GGML_ASSERT(d_Qy != nullptr);
7360 }
7361 if (!ids_uma) {
7362 d_ids = ids_buf_ctx->dev_buffer;
7363 ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
7364 GGML_ASSERT(d_ids != nullptr);
7365 }
7366 if (qx_needs_dequant) {
7367 d_X = ctx->prealloc_x;
7368 GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
7369 } else {
7370 d_X = d_Qx;
7371 x_buf_offset = qx_buf_offset;
7372 GGML_ASSERT(qx_sz == x_sz);
7373 }
7374 if (qy_needs_dequant) {
7375 d_Y = ctx->prealloc_y;
7376 GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
7377 } else if (quantize_y) {
7378 d_Y = ctx->prealloc_y;
7379 GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
7380 } else {
7381 d_Y = d_Qy;
7382 y_buf_offset = qy_buf_offset;
7383 GGML_ASSERT(qy_sz == y_sz);
7384 }
7385
7386 if (x_non_contig || qx_needs_dequant) {
7387 if (ctx->prealloc_x_need_sync) {
7388 ggml_vk_sync_buffers(ctx, subctx);
7389 }
7390 }
7391
7392 if (x_non_contig) {
7393 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
7394 } else if (qx_needs_dequant) {
7395 const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
7396 ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
7397 { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
7398 ggml_vk_sync_buffers(ctx, subctx);
7399 }
7400 if (y_non_contig) {
7401 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
7402 ctx->prealloc_y_last_tensor_used != src1) {
7403 if (ctx->prealloc_y_need_sync) {
7404 ggml_vk_sync_buffers(ctx, subctx);
7405 }
7406 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
7407 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
7408 ctx->prealloc_y_last_tensor_used = src1;
7409 }
7410 }
7411 if (quantize_y) {
7412 if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
7413 ctx->prealloc_y_last_tensor_used != src1) {
7414 if (ctx->prealloc_y_need_sync) {
7415 ggml_vk_sync_buffers(ctx, subctx);
7416 }
7417 ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
7418 ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
7419 ctx->prealloc_y_last_tensor_used = src1;
7420 }
7421 }
7422
7423 uint32_t stride_batch_x = ne00*ne01;
7424 uint32_t stride_batch_y = ne10*ne11;
7425
7426 if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
7427 stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
7428 }
7429
7430 if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
7431 stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
7432 }
7433
7434 uint32_t y_sz_total = y_sz * ne12 * ne13;
7435 if (quantize_y) {
7436 y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
7437 }
7438
7439 // compute
7440 ggml_vk_matmul_id(
7441 ctx, subctx, pipeline,
7442 { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
7443 { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
7444 ne01, ne21, ne10, ne10, ne10, ne01,
7445 stride_batch_x, stride_batch_y, ne20*ne21,
7446 n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
7447 ); // NOLINT
7448
7449 if (x_non_contig || qx_needs_dequant) {
7450 ctx->prealloc_x_need_sync = true;
7451 }
7452 if (y_non_contig) {
7453 ctx->prealloc_y_need_sync = true;
7454 }
7455}
7456
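// Matrix-vector variant of mul_mat_id, used when the ids tensor has a single row (nei1 == 1).
// Optionally fuses the following add node as a bias, mirroring the non-id mul_mat_vec paths.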
7457static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
7458 ggml_tensor * dst = cgraph->nodes[node_idx];
7459 ggml_tensor * src0 = dst->src[0];
7460 ggml_tensor * src1 = dst->src[1];
7461 ggml_tensor * ids = dst->src[2];
7462 VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7463 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7464 std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
7465 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
7466 std::cerr << "))");
7467 GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
7468 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
7469 GGML_ASSERT(ids->type == GGML_TYPE_I32);
7470
7471 const uint64_t ne00 = src0->ne[0];
7472 const uint64_t ne01 = src0->ne[1];
7473 const uint64_t ne02 = src0->ne[2];
7474 const uint64_t ne03 = src0->ne[3];
7475
7476 const uint64_t ne10 = src1->ne[0];
7477 const uint64_t ne11 = src1->ne[1];
7478 const uint64_t ne12 = src1->ne[2];
7479 const uint64_t ne13 = src1->ne[3];
7480
7481 const uint64_t nei0 = ids->ne[0];
7482 const uint64_t nei1 = ids->ne[1];
7483
7484 const uint64_t nbi2 = ids->nb[2];
7485
7486 GGML_ASSERT(nei1 == 1);
7487
7488 const uint64_t ne20 = dst->ne[0];
7489 const uint64_t ne21 = dst->ne[1];
7490 const uint64_t ne22 = dst->ne[2];
7491 const uint64_t ne23 = dst->ne[3];
7492
7493 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
7494 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
7495 ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
7496
7497 vk_buffer d_Qx = nullptr;
7498 size_t qx_buf_offset = 0;
7499 vk_buffer d_Qy = nullptr;
7500 size_t qy_buf_offset = 0;
7501 vk_buffer d_ids = nullptr;
7502 size_t ids_buf_offset = 0;
7503
7504 bool src0_uma = false;
7505 bool src1_uma = false;
7506 bool ids_uma = false;
7507
7508 if (ctx->device->uma) {
7509 ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
7510 ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
7511 ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
7512 src0_uma = d_Qx != nullptr;
7513 src1_uma = d_Qy != nullptr;
7514 ids_uma = d_ids != nullptr;
7515 }
7516
7517 const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
7518 const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
7519
7520 const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
7521
7522 const bool qx_needs_dequant = x_non_contig;
7523 const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
7524
7525 // Not implemented
7526 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
7527
7528 const uint64_t x_ne = ne01 * ne00;
7529 const uint64_t y_ne = ne11 * ne10;
7530 const uint64_t d_ne = ne21 * ne20;
7531
7532 const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
7533 const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
7534 const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
7535 const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
7536 const uint64_t ids_sz = nbi2;
7537 const uint64_t d_sz = sizeof(float) * d_ne;
7538
7539 vk_pipeline to_fp16_vk_0 = nullptr;
7540 vk_pipeline to_fp16_vk_1 = nullptr;
7541 if (x_non_contig) {
7542 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
7543 }
7544 if (y_non_contig) {
7545 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
7546 } else {
7547 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
7548 }
7549 vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
7550 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
7551 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
7552 GGML_ASSERT(dmmv != nullptr);
7553
7554 {
7555 const uint64_t x_sz_upd = x_sz * ne02 * ne03;
7556 const uint64_t y_sz_upd = y_sz * ne12 * ne13;
7557 if (
7558 (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
7559 (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
7560 GGML_ABORT("Requested preallocation size is too large");
7561 }
7562 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
7563 ctx->prealloc_size_x = x_sz_upd;
7564 ggml_vk_preallocate_buffers(ctx, subctx);
7565 }
7566 if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
7567 ctx->prealloc_size_y = y_sz_upd;
7568 ggml_vk_preallocate_buffers(ctx, subctx);
7569 }
7570
7571 // Request descriptor sets
7572 if (qx_needs_dequant) {
7573 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
7574 }
7575 if (qy_needs_dequant) {
7576 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
7577 }
7578 ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
7579 }
7580
7581 vk_buffer d_D;
7582 uint64_t d_buf_offset = 0;
7583
7584 if (ctx->num_additional_fused_ops > 0) {
7585 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
7586 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)add->buffer->context;
7587 d_D = dst_buf_ctx->dev_buffer;
7588 d_buf_offset = vk_tensor_offset(add) + add->view_offs;
7589 } else {
7590 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
7591 d_D = dst_buf_ctx->dev_buffer;
7592 d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
7593 }
7594
7595 GGML_ASSERT(d_D != nullptr);
7596 vk_buffer d_X;
7597 uint64_t x_buf_offset = 0;
7598 vk_buffer d_Y;
7599 uint64_t y_buf_offset = 0;
7600 if(!src0_uma) {
7601 d_Qx = src0_buf_ctx->dev_buffer;
7602 qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
7603 GGML_ASSERT(d_Qx != nullptr);
7604 }
7605 if(!src1_uma) {
7606 d_Qy = src1_buf_ctx->dev_buffer;
7607 qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
7608 GGML_ASSERT(d_Qy != nullptr);
7609 }
7610 if(!ids_uma) {
7611 d_ids = ids_buf_ctx->dev_buffer;
7612 ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
7613 GGML_ASSERT(d_ids != nullptr);
7614 }
7615 if (qx_needs_dequant) {
7616 d_X = ctx->prealloc_x;
7617 } else {
7618 d_X = d_Qx;
7619 x_buf_offset = qx_buf_offset;
7620 GGML_ASSERT(qx_sz == x_sz);
7621 }
7622 if (qy_needs_dequant) {
7623 d_Y = ctx->prealloc_y;
7624 } else {
7625 d_Y = d_Qy;
7626 y_buf_offset = qy_buf_offset;
7627 GGML_ASSERT(qy_sz == y_sz);
7628 }
7629
7630 if (x_non_contig) {
7631 if (ctx->prealloc_x_need_sync) {
7632 ggml_vk_sync_buffers(ctx, subctx);
7633 }
7634 }
7635
7636 if (x_non_contig) {
7637 GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
7638 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
7639 }
7640 if (y_non_contig) {
7641 GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
7642 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
7643 ctx->prealloc_y_last_tensor_used != src1) {
7644 if (ctx->prealloc_y_need_sync) {
7645 ggml_vk_sync_buffers(ctx, subctx);
7646 }
7647 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
7648 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
7649 ctx->prealloc_y_last_tensor_used = src1;
7650 }
7651 }
7652
7653 uint32_t stride_batch_y = ne10*ne11;
7654
7655 if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
7656 stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
7657 }
7658
7659 const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
7660
7661 uint32_t groups_x = ne01;
7662 uint32_t groups_z = 1;
7663
7664 if (ne01 > max_groups_x) {
7665 groups_z = 64;
7666 groups_x = CEIL_DIV(groups_x, groups_z);
7667 }
7668
7669 uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
7670
7671 vk_buffer d_B = d_D;
7672 size_t b_buf_offset = 0;
7673 uint64_t b_sz = 0;
7674
7675 if (enable_bias) {
7676 const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
7677
7678 bool b_uma = false;
7679 if (ctx->device->uma) {
7680 ggml_vk_host_get(ctx->device, bias->data, d_B, b_buf_offset);
7681 b_uma = d_B != nullptr;
7682 }
7683 if(!b_uma) {
7684 ggml_backend_vk_buffer_context * bias_buf_ctx = (ggml_backend_vk_buffer_context *)bias->buffer->context;
7685 d_B = bias_buf_ctx->dev_buffer;
7686 b_buf_offset = vk_tensor_offset(bias) + bias->view_offs;
7687 GGML_ASSERT(d_B != nullptr);
7688 b_sz = ggml_nbytes(bias);
7689 }
7690 }
7691
7692 // compute
7693 const vk_mat_vec_id_push_constants pc = {
7694 (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
7695 (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
7696
7697 enable_bias,
7698
7699 (uint32_t)nei0, (uint32_t)ne11,
7700 };
7701 ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
7702 {
7703 vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
7704 vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 },
7705 vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23},
7706 vk_subbuffer{ d_B, b_buf_offset, b_sz },
7707 vk_subbuffer{ d_ids, ids_buf_offset, ids_sz },
7708 },
7709 pc, { groups_x, (uint32_t)nei0, groups_z });
7710
7711 if (x_non_contig) {
7712 ctx->prealloc_x_need_sync = true;
7713 }
7714 if (y_non_contig) {
7715 ctx->prealloc_y_need_sync = true;
7716 }
7717}
7718
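// The matrix-vector id path is used when there is a single set of expert ids per token
// (ids->ne[1] == 1) and src0 is F32, F16 or a quantized type.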
7719static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int node_idx) {
7720 ggml_tensor * dst = cgraph->nodes[node_idx];
7721 ggml_tensor * src0 = dst->src[0];
7722 ggml_tensor * src2 = dst->src[2];
7723 return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
7724}
7725
7726static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
7727 ggml_tensor * dst = cgraph->nodes[node_idx];
7728 ggml_tensor * src0 = dst->src[0];
7729 ggml_tensor * src1 = dst->src[1];
7730 ggml_tensor * src2 = dst->src[2];
7731 VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
7732 if (ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
7733 ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, cgraph, node_idx);
7734 } else {
7735 ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);
7736 }
7737}
7738
7739static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
7740 // Needs to be kept up to date on shader changes
7741 GGML_UNUSED(hsv);
7742 const uint32_t wg_size = scalar_flash_attention_workgroup_size;
7743 const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
7744 const uint32_t Bc = scalar_flash_attention_Bc;
7745
7746 const uint32_t tmpsh = wg_size * sizeof(float);
7747 const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
7748
7749 const uint32_t masksh = Bc * Br * sizeof(float);
7750
7751 const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
7752
7753 const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
7754 const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
7755
7756 VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
7757
7758 return supported;
7759}
7760
7761static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
7762 // Needs to be kept up to date on shader changes
7763 GGML_UNUSED(hsv);
7764 const uint32_t wg_size = scalar_flash_attention_workgroup_size;
7765 const uint32_t Br = coopmat1_flash_attention_num_large_rows;
7766 const uint32_t Bc = scalar_flash_attention_Bc;
7767
7768 const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
7769
7770 const uint32_t acctype = f32acc ? 4 : 2;
7771 const uint32_t f16vec4 = 8;
7772
7773 const uint32_t tmpsh = wg_size * sizeof(float);
7774 const uint32_t tmpshv4 = wg_size * 4 * acctype;
7775
7776 const uint32_t qstride = hsk_pad / 4 + 2;
7777 const uint32_t Qf = Br * qstride * f16vec4;
7778
7779 const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
7780 const uint32_t sfsh = Bc * sfshstride * acctype;
7781
7782 const uint32_t kshstride = hsk_pad / 4 + 2;
7783 const uint32_t ksh = Bc * kshstride * f16vec4;
7784
7785 const uint32_t slope = Br * sizeof(float);
7786
7787 const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
7788 const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
7789
7790 VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
7791
7792 return supported;
7793}
7794
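// Flash attention dispatch. Chooses between the scalar, coopmat1 and coopmat2 shader paths
// based on device support and problem shape, applies the gqa_ratio optimization when N == 1,
// and optionally splits the KV dimension (split_k) with a separate reduction pass.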
7795static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) {
7796 VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
7797 std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
7798 std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
7799 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
7800 if (sinks) {
7801 std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
7802 }
7803 std::cerr << "))");
7804
7805 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
7806 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
7807 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
7808 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
7809 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
7810 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
7811 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
7812 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
7813
7814 const uint32_t nem1 = mask ? mask->ne[1] : 0;
7815 const uint32_t nem2 = mask ? mask->ne[2] : 0;
7816 const uint32_t nem3 = mask ? mask->ne[3] : 0;
7817
7818 const uint32_t HSK = nek0;
7819 const uint32_t HSV = nev0;
7820 uint32_t N = neq1;
7821 const uint32_t KV = nek1;
7822
7823 GGML_ASSERT(ne0 == HSV);
7824 GGML_ASSERT(ne2 == N);
7825
7826 // input tensor rows must be contiguous
7827 GGML_ASSERT(nbq0 == ggml_type_size(q->type));
7828 GGML_ASSERT(nbk0 == ggml_type_size(k->type));
7829 GGML_ASSERT(nbv0 == ggml_type_size(v->type));
7830
7831 GGML_ASSERT(neq0 == HSK);
7832
7833 GGML_ASSERT(neq1 == N);
7834
7835 GGML_ASSERT(nev1 == nek1);
7836
7837 // dst cannot be transposed or permuted
7838 GGML_ASSERT(nb0 == sizeof(float));
7839 GGML_ASSERT(nb0 <= nb1);
7840 GGML_ASSERT(nb1 <= nb2);
7841 GGML_ASSERT(nb2 <= nb3);
7842
7843 assert(dst->type == GGML_TYPE_F32);
7844 assert(q->type == GGML_TYPE_F32);
7845 assert(k->type == v->type);
7846
7847 FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
7848 ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
7849
7850 if (path == FA_COOPMAT1) {
7851 const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
7852 (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
7853
7854 const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
7855
7856 if (!coopmat_shape_supported || !coopmat_shmem_supported) {
7857 path = FA_SCALAR;
7858 }
7859 }
7860
7861 uint32_t gqa_ratio = 1;
7862 uint32_t qk_ratio = neq2 / nek2;
7863 uint32_t workgroups_x = (uint32_t)neq1;
7864 uint32_t workgroups_y = (uint32_t)neq2;
7865 uint32_t workgroups_z = (uint32_t)neq3;
7866
7867 // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
7868 // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
7869 uint32_t max_gqa;
7870 switch (path) {
7871 case FA_SCALAR:
7872 case FA_COOPMAT1:
7873 // We may switch from coopmat1 to scalar, so use the scalar limit for both
7874 max_gqa = get_fa_scalar_num_large_rows(HSV);
7875 break;
7876 case FA_COOPMAT2:
7877 max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
7878 break;
7879 default:
7880 GGML_ASSERT(0);
7881 }
7882
7883 if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
7884 qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
7885 // grouped query attention - make the N dimension equal to gqa_ratio, reduce
7886 // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
7887 // and change addressing calculations to index Q's dimension 2.
7888 gqa_ratio = qk_ratio;
7889 N = gqa_ratio;
7890 workgroups_y /= N;
7891 }
7892
7893 bool small_rows = N <= get_fa_num_small_rows(path);
7894
7895 // coopmat1 does not actually support "small rows" (it needs 16 rows).
7896 // So use scalar instead.
7897 if (small_rows && path == FA_COOPMAT1) {
7898 path = FA_SCALAR;
7899 }
7900
7901 // scalar is faster than coopmat2 when N==1
7902 if (N == 1 && path == FA_COOPMAT2) {
7903 path = FA_SCALAR;
7904 }
7905
7906 // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
7907 if (path == FA_SCALAR &&
7908 !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
7909 small_rows = true;
7910 }
7911
7912 const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
7913 uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
7914 uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
7915
7916 // For F32, the shader treats it as a block of size 4 (for vec4 loads)
7917 if (k->type == GGML_TYPE_F32) {
7918 k_stride /= 4;
7919 }
7920 if (v->type == GGML_TYPE_F32) {
7921 v_stride /= 4;
7922 }
7923
7924 uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
7925 bool aligned = (KV % alignment) == 0 &&
7926 // the "aligned" shader variant will forcibly align strides, for performance
7927 (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
7928
7929 // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
7930 if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
7931 aligned = false;
7932 }
7933
7934 bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
7935
7936 vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
7937
7938 vk_pipeline pipeline = nullptr;
7939
7940 {
7941 std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
7942 auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
7943 auto it = pipelines.find(fa_pipeline_state);
7944 if (it != pipelines.end()) {
7945 pipeline = it->second;
7946 } else {
7947 pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
7948 }
7949 }
7950
7951 assert(pipeline);
7952
7953 uint32_t split_kv = KV;
7954 uint32_t split_k = 1;
7955
7956 // Use a placeholder core count if one isn't available. split_k is a big help for perf.
7957 const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
7958
7959 // Try to use split_k when KV is large enough to be worth the overhead
7960 if (workgroups_x == 1 && shader_core_count > 0) {
7961 // Try to run two workgroups per SM.
7962 split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
7963 if (split_k > 1) {
7964 // Try to evenly split KV into split_k chunks, but it needs to be a multiple
7965 // of "align", so recompute split_k based on that.
7966 split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
7967 split_k = CEIL_DIV(KV, split_kv);
7968 workgroups_x = split_k;
7969 }
7970 }
7971
7972 // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
7973 // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
7974 const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
7975 if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
7976 GGML_ABORT("Requested preallocation size is too large");
7977 }
7978 if (ctx->prealloc_size_split_k < split_k_size) {
7979 ctx->prealloc_size_split_k = split_k_size;
7980 ggml_vk_preallocate_buffers(ctx, subctx);
7981 }
7982
7983 {
7984 // Request descriptor sets
7985 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
7986 if (split_k > 1) {
7987 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
7988 }
7989 }
7990
7991 float scale = 1.0f;
7992 float max_bias = 0.0f;
7993 float logit_softcap = 0.0f;
7994
7995 memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
7996 memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
7997 memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
7998
7999 if (logit_softcap != 0) {
8000 scale /= logit_softcap;
8001 }
8002
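// ALiBi slope bases: m0 and m1 follow the usual ALiBi formulation; the shader derives the
// per-head slope from these together with n_head_log2 when max_bias > 0.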
8003 const uint32_t n_head_kv = neq2;
8004 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
8005 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8006 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8007
8008 vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
8009 vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
8010 vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
8011 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
8012 vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
8013 vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
8014
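// Pack the "has sinks" flag (bit 24), the "has mask" flag (bit 16) and n_head_log2
// (low bits) into a single push-constant word.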
8015 uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
8016
8017 const vk_flash_attn_push_constants pc = { N, KV,
8018 (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
8019 (uint32_t)neq2, (uint32_t)neq3,
8020 (uint32_t)nek2, (uint32_t)nek3,
8021 (uint32_t)nev2, (uint32_t)nev3,
8022 nem1, nem2, nem3,
8023 q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
8024 k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
8025 v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
8026 scale, max_bias, logit_softcap,
8027 mask_n_head_log2, m0, m1,
8028 gqa_ratio, split_kv, split_k };
8029
8030 if (split_k > 1) {
8031 if (ctx->prealloc_split_k_need_sync) {
8032 ggml_vk_sync_buffers(ctx, subctx);
8033 }
8034
8035 vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
8036 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8037 {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
8038 // We only use split_k when group query attention is enabled, which means
8039 // there's no more than one tile of rows (i.e. workgroups_x would have been
8040 // one). We reuse workgroups_x to mean the number of splits, so we need to
8041 // cancel out the divide by wg_denoms[0].
8042 pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
8043
8044 ggml_vk_sync_buffers(ctx, subctx);
8045 const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
8046 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
8047 {split_k_buf, sinks_buf, dst_buf},
8048 pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
8049 ctx->prealloc_split_k_need_sync = true;
8050 } else {
8051 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8052 {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
8053 pc, { workgroups_x, workgroups_y, workgroups_z });
8054 }
8055}
8056
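// Workgroup counts for CONV_2D: one tile per block of Cout output channels and per block of
// N*OW*OH output pixels, matching the shader's BS_K/BS_NPQ tiling.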
8057static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) {
8058 const ggml_tensor *src0 = dst->src[0];
8059 const ggml_tensor *src1 = dst->src[1];
8060
8061 // src0 - kernel: [KW, KH, Cin, Cout]
8062 // src1 - input: [W, H, Cin, N]
8063 // dst - result: [OW, OH, Cout, N]
8064
8065 // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
8066 auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
8067 return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
8068 };
8069 // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
8070 int64_t W = src1->ne[0];
8071 int64_t H = src1->ne[1];
8072 int64_t KW = src0->ne[0];
8073 int64_t KH = src0->ne[1];
8074 int64_t Cout = src0->ne[3];
8075 int64_t N = src1->ne[3];
8076 int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
8077 int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
8078 int64_t NPQ = N * OW * OH;
8079
8080 // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
8081 std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
8082 return elements;
8083}
8084
8085static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) {
8086 const ggml_tensor *src0 = dst->src[0];
8087 const ggml_tensor *src1 = dst->src[1];
8088
8089 // src0 - kernel: [KW, KH, Cout, Cin]
8090 // src1 - input: [W, H, Cin, N]
8091 // dst - result: [OW, OH, Cout, N]
8092
8093 auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
8094 return (ins - 1) * s - 2 * p + (ks - 1) * d + 1;
8095 };
8096 // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
8097 int64_t W = src1->ne[0];
8098 int64_t H = src1->ne[1];
8099 int64_t KW = src0->ne[0];
8100 int64_t KH = src0->ne[1];
8101 int64_t Cout = src0->ne[2];
8102 int64_t N = src1->ne[3];
8103 int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1);
8104 int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1);
8105 int64_t NPQ = N * OW * OH;
8106
8107 // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
8108 std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
8109 return elements;
8110}
8111
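// Select the compute pipeline for a generic ggml op based on the op and the source/destination
// tensor types; returns nullptr when the combination is not supported by this backend.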
8112static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
8113 switch (op) {
8114 case GGML_OP_GET_ROWS:
8115 GGML_ASSERT(src1->type == GGML_TYPE_I32);
8116 if (dst->type == GGML_TYPE_F16) {
8117 return ctx->device->pipeline_get_rows[src0->type];
8118 }
8119 if (dst->type == GGML_TYPE_F32) {
8120 return ctx->device->pipeline_get_rows_f32[src0->type];
8121 }
8122 return nullptr;
8123 case GGML_OP_ACC:
8124 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8125 return ctx->device->pipeline_acc_f32;
8126 }
8127 return nullptr;
8128 case GGML_OP_ADD:
8129 case GGML_OP_SUB:
8130 case GGML_OP_MUL:
8131 case GGML_OP_DIV:
8132 if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
8133 (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
8134 (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
8135 return nullptr;
8136 }
8137 switch (op) {
8138 case GGML_OP_ADD:
8139 {
8140 if (ctx->num_additional_fused_ops > 0) {
8141 if (ctx->do_add_rms_partials) {
8142 return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
8143 } else {
8144 return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
8145 }
8146 }
8147 if (ctx->do_add_rms_partials) {
8148 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
8149 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8150 } else {
8151 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
8152 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8153 }
8154 }
8155 case GGML_OP_SUB:
8156 {
8157 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
8158 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8159 }
8160 case GGML_OP_MUL:
8161 {
8162 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
8163 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8164 }
8165 case GGML_OP_DIV:
8166 {
8167 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
8168 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8169 }
8170 default:
8171 break;
8172 }
8173 return nullptr;
8174 case GGML_OP_ADD_ID:
8175 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
8176 return ctx->device->pipeline_add_id_f32;
8177 }
8178 return nullptr;
8179 case GGML_OP_CONCAT:
8180 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8181 return ctx->device->pipeline_concat_f32;
8182 }
8183 if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
8184 return ctx->device->pipeline_concat_f16;
8185 }
8186 if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
8187 return ctx->device->pipeline_concat_i32;
8188 }
8189 return nullptr;
8190 case GGML_OP_UPSCALE:
8191 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8192 ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF);
8193 switch (mode) {
8194 case GGML_SCALE_MODE_NEAREST:
8195 return ctx->device->pipeline_upscale_nearest_f32;
8196 case GGML_SCALE_MODE_BILINEAR:
8197 return ctx->device->pipeline_upscale_bilinear_f32;
8198 default:
8199 return nullptr;
8200 }
8201 }
8202 return nullptr;
8203 case GGML_OP_SCALE:
8204 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8205 return ctx->device->pipeline_scale_f32;
8206 }
8207 return nullptr;
8208 case GGML_OP_SQR:
8209 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8210 return ctx->device->pipeline_sqr_f32;
8211 }
8212 return nullptr;
8213 case GGML_OP_SQRT:
8214 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8215 return ctx->device->pipeline_sqrt_f32;
8216 }
8217 return nullptr;
8218 case GGML_OP_SIN:
8219 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8220 return ctx->device->pipeline_sin_f32;
8221 }
8222 return nullptr;
8223 case GGML_OP_COS:
8224 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8225 return ctx->device->pipeline_cos_f32;
8226 }
8227 return nullptr;
8228 case GGML_OP_CLAMP:
8229 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8230 return ctx->device->pipeline_clamp_f32;
8231 }
8232 return nullptr;
8233 case GGML_OP_PAD:
8234 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8235 return ctx->device->pipeline_pad_f32;
8236 }
8237 return nullptr;
8238 case GGML_OP_ROLL:
8239 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8240 return ctx->device->pipeline_roll_f32;
8241 }
8242 return nullptr;
8243 case GGML_OP_REPEAT:
8244 if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
8245 return ctx->device->pipeline_repeat_f32;
8246 }
8247 return nullptr;
8248 case GGML_OP_REPEAT_BACK:
8249 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8250 return ctx->device->pipeline_repeat_back_f32;
8251 }
8252 return nullptr;
8253 case GGML_OP_CPY:
8254 case GGML_OP_CONT:
8255 case GGML_OP_DUP:
8256 return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
8257 case GGML_OP_SET_ROWS:
8258 if (src1->type == GGML_TYPE_I64) {
8259 return ctx->device->pipeline_set_rows_i64[dst->type];
8260 } else {
8261 return ctx->device->pipeline_set_rows_i32[dst->type];
8262 }
8263 case GGML_OP_SILU_BACK:
8264 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8265 return ctx->device->pipeline_silu_back_f32;
8266 }
8267 return nullptr;
8268 case GGML_OP_NORM:
8269 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8270 return ctx->device->pipeline_norm_f32;
8271 }
8272 return nullptr;
8273 case GGML_OP_GROUP_NORM:
8274 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8275 return ctx->device->pipeline_group_norm_f32;
8276 }
8277 return nullptr;
8278 case GGML_OP_RMS_NORM:
8279 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8280 if (ctx->do_add_rms_partials) {
8281 return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
8282 } else {
8283 return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
8284 }
8285 }
8286 return nullptr;
8287 case GGML_OP_RMS_NORM_BACK:
8288 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8289 return ctx->device->pipeline_rms_norm_back_f32;
8290 }
8291 return nullptr;
8292 case GGML_OP_L2_NORM:
8293 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8294 return ctx->device->pipeline_l2_norm_f32;
8295 }
8296 return nullptr;
8297 case GGML_OP_UNARY:
8298 if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
8299 (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
8300 (src0->type != dst->type)) {
8301 return nullptr;
8302 }
8303
8304 switch (ggml_get_unary_op(tensor: dst)) {
8305 case GGML_UNARY_OP_EXP:
8306 return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
8307 case GGML_UNARY_OP_SILU:
8308 return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
8309 case GGML_UNARY_OP_GELU:
8310 return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
8311 case GGML_UNARY_OP_GELU_ERF:
8312 return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
8313 case GGML_UNARY_OP_GELU_QUICK:
8314 return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
8315 case GGML_UNARY_OP_RELU:
8316 return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
8317 case GGML_UNARY_OP_TANH:
8318 return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
8319 case GGML_UNARY_OP_SIGMOID:
8320 return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
8321 case GGML_UNARY_OP_HARDSIGMOID:
8322 return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
8323 case GGML_UNARY_OP_HARDSWISH:
8324 return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
8325 default:
8326 break;
8327 }
8328 return nullptr;
8329 case GGML_OP_GLU:
8330 if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
8331 (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
8332 (src0->type != dst->type)) {
8333 return nullptr;
8334 }
8335
8336 switch (ggml_get_glu_op(tensor: dst)) {
8337 case GGML_GLU_OP_GEGLU:
8338 return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
8339 case GGML_GLU_OP_REGLU:
8340 return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
8341 case GGML_GLU_OP_SWIGLU:
8342 return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
8343 case GGML_GLU_OP_SWIGLU_OAI:
8344 return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
8345 case GGML_GLU_OP_GEGLU_ERF:
8346 return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
8347 case GGML_GLU_OP_GEGLU_QUICK:
8348 return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
8349 default:
8350 break;
8351 }
8352 return nullptr;
8353 case GGML_OP_DIAG_MASK_INF:
8354 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8355 return ctx->device->pipeline_diag_mask_inf_f32;
8356 }
8357 return nullptr;
8358 case GGML_OP_SOFT_MAX:
8359 GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
8360 GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
8361
8362 if (ctx->num_additional_fused_ops) {
8363 uint32_t idx = (uint32_t)ceilf(x: log2f(x: float(dst->ne[0])));
8364 GGML_ASSERT(idx < num_topk_moe_pipelines);
8365 topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(num: ctx->num_additional_fused_ops);
8366 return ctx->device->pipeline_topk_moe[idx][mode];
8367 }
8368
8369 if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
8370 return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
8371 }
8372 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
8373 return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
8374 }
8375 return nullptr;
8376 case GGML_OP_SOFT_MAX_BACK:
8377 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8378 return ctx->device->pipeline_soft_max_back_f32;
8379 }
8380 return nullptr;
8381 case GGML_OP_ROPE:
8382 case GGML_OP_ROPE_BACK:
8383 {
8384 const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;
8385 const int mode = ((const int32_t *) rope->op_params)[2];
8386 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
8387 const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
8388 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
8389
8390 if (is_neox) {
8391 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8392 return ctx->device->pipeline_rope_neox_f32;
8393 }
8394 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
8395 return ctx->device->pipeline_rope_neox_f32_f16;
8396 }
8397 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
8398 return ctx->device->pipeline_rope_neox_f16;
8399 }
8400 } else if (is_mrope && !is_vision) {
8401 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8402 return ctx->device->pipeline_rope_multi_f32;
8403 }
8404 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
8405 return ctx->device->pipeline_rope_multi_f16;
8406 }
8407 } else if (is_vision) {
8408 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8409 return ctx->device->pipeline_rope_vision_f32;
8410 }
8411 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
8412 return ctx->device->pipeline_rope_vision_f16;
8413 }
8414 } else {
8415 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8416 return ctx->device->pipeline_rope_norm_f32;
8417 }
8418 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
8419 return ctx->device->pipeline_rope_norm_f32_f16;
8420 }
8421 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
8422 return ctx->device->pipeline_rope_norm_f16;
8423 }
8424 }
8425 return nullptr;
8426 }
8427 case GGML_OP_ARGSORT:
8428 if (ctx->num_additional_fused_ops) {
8429 uint32_t idx = (uint32_t)ceilf(x: log2f(x: float(dst->ne[0])));
8430 GGML_ASSERT(idx < num_topk_moe_pipelines);
8431 topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(num: ctx->num_additional_fused_ops);
8432 return ctx->device->pipeline_topk_moe[idx][mode];
8433 }
8434
8435 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
8436 uint32_t idx = (uint32_t)ceilf(x: log2f(x: float(dst->ne[0])));
8437 return ctx->device->pipeline_argsort_f32[idx];
8438 }
8439 return nullptr;
8440 case GGML_OP_SUM:
8441 case GGML_OP_SUM_ROWS:
8442 case GGML_OP_MEAN:
8443 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8444 return ctx->device->pipeline_sum_rows_f32;
8445 }
8446 return nullptr;
8447 case GGML_OP_ARGMAX:
8448 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
8449 return ctx->device->pipeline_argmax_f32;
8450 }
8451 return nullptr;
8452 case GGML_OP_COUNT_EQUAL:
8453 if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
8454 return ctx->device->pipeline_count_equal_i32;
8455 }
8456 return nullptr;
8457 case GGML_OP_IM2COL:
8458 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8459 return ctx->device->pipeline_im2col_f32;
8460 }
8461 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
8462 return ctx->device->pipeline_im2col_f32_f16;
8463 }
8464 return nullptr;
8465 case GGML_OP_IM2COL_3D:
8466 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8467 return ctx->device->pipeline_im2col_3d_f32;
8468 }
8469 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
8470 return ctx->device->pipeline_im2col_3d_f32_f16;
8471 }
8472 return nullptr;
8473 case GGML_OP_TIMESTEP_EMBEDDING:
8474 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8475 return ctx->device->pipeline_timestep_embedding_f32;
8476 }
8477 return nullptr;
8478 case GGML_OP_CONV_TRANSPOSE_1D:
8479 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8480 return ctx->device->pipeline_conv_transpose_1d_f32;
8481 }
8482 return nullptr;
8483 case GGML_OP_POOL_2D:
8484 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8485 return ctx->device->pipeline_pool2d_f32;
8486 }
8487 return nullptr;
8488 case GGML_OP_RWKV_WKV6:
8489 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8490 return ctx->device->pipeline_rwkv_wkv6_f32;
8491 }
8492 return nullptr;
8493 case GGML_OP_RWKV_WKV7:
8494 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8495 return ctx->device->pipeline_rwkv_wkv7_f32;
8496 }
8497 return nullptr;
8498 case GGML_OP_SSM_SCAN:
8499 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8500 const uint32_t d_state = src0->ne[0];
8501 if (d_state == 128) {
8502 return ctx->device->pipeline_ssm_scan_f32_d128;
8503 } else if (d_state == 256) {
8504 return ctx->device->pipeline_ssm_scan_f32_d256;
8505 }
8506 }
8507 return nullptr;
8508 case GGML_OP_SSM_CONV:
8509 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8510 return ctx->device->pipeline_ssm_conv_f32;
8511 }
8512 return nullptr;
8513 case GGML_OP_OPT_STEP_ADAMW:
8514 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8515 return ctx->device->pipeline_opt_step_adamw_f32;
8516 }
8517 return nullptr;
8518 case GGML_OP_OPT_STEP_SGD:
8519 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8520 return ctx->device->pipeline_opt_step_sgd_f32;
8521 }
8522 return nullptr;
8523 case GGML_OP_LEAKY_RELU:
8524 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8525 return ctx->device->pipeline_leaky_relu_f32;
8526 }
8527 return nullptr;
8528 case GGML_OP_CONV_2D:
8529 case GGML_OP_CONV_TRANSPOSE_2D:
8530 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
8531 ggml_is_contiguous(tensor: src0) && ggml_is_contiguous(tensor: src1) && ggml_is_contiguous(tensor: dst)) {
8532 std::array<uint32_t, 3> elements;
8533 if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
8534 else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
8535 vk_conv_shapes shape;
8536
8537 uint32_t tiles[CONV_SHAPE_COUNT];
8538 for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
8539 tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
8540 }
8541
8542 // We can't query the number of shader cores on Intel; use 32 as a placeholder
8543 // so that small convolutions will still choose a smaller tile.
8544 const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
8545
8546 if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) {
8547 shape = CONV_SHAPE_128x128;
8548 } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) {
8549 shape = CONV_SHAPE_32x256;
8550 } else {
8551 shape = CONV_SHAPE_64x32;
8552 }
8553
8554 if (op == GGML_OP_CONV_2D) {
8555 if (src0->type == GGML_TYPE_F32) {
8556 return ctx->device->pipeline_conv2d_f32[shape];
8557 } else if (src0->type == GGML_TYPE_F16) {
8558 return ctx->device->pipeline_conv2d_f16_f32[shape];
8559 }
8560 } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
8561 if (src0->type == GGML_TYPE_F32) {
8562 return ctx->device->pipeline_conv_transpose_2d_f32[shape];
8563 } else if (src0->type == GGML_TYPE_F16) {
8564 return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
8565 }
8566 }
8567 }
8568 return nullptr;
8569 case GGML_OP_CONV_2D_DW:
8570 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8571 if (ggml_is_contiguous(tensor: src1)) {
8572 return ctx->device->pipeline_conv2d_dw_whcn_f32;
8573 } else if (ggml_is_contiguous_channels(tensor: src1)) {
8574 return ctx->device->pipeline_conv2d_dw_cwhn_f32;
8575 }
8576 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
8577 if (ggml_is_contiguous(tensor: src1)) {
8578 return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
8579 } else if (ggml_is_contiguous_channels(tensor: src1)) {
8580 return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
8581 }
8582 }
8583 return nullptr;
8584 default:
8585 return nullptr;
8586 }
8587
8588 GGML_UNUSED(src2);
8589}
8590
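// Ops listed here have shaders that index through the full nb[] strides passed in the
// push constants, so they can operate on non-contiguous tensors. Everything else is
// required (and asserted below in ggml_vk_op_f32) to have dim01-contiguous inputs.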
8591static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
8592 switch (op) {
8593 case GGML_OP_CPY:
8594 case GGML_OP_GET_ROWS:
8595 case GGML_OP_ADD:
8596 case GGML_OP_SUB:
8597 case GGML_OP_MUL:
8598 case GGML_OP_DIV:
8599 case GGML_OP_ADD_ID:
8600 case GGML_OP_CONCAT:
8601 case GGML_OP_UPSCALE:
8602 case GGML_OP_SQR:
8603 case GGML_OP_SQRT:
8604 case GGML_OP_SIN:
8605 case GGML_OP_COS:
8606 case GGML_OP_CLAMP:
8607 case GGML_OP_PAD:
8608 case GGML_OP_REPEAT:
8609 case GGML_OP_REPEAT_BACK:
8610 case GGML_OP_ROPE:
8611 case GGML_OP_RMS_NORM:
8612 case GGML_OP_CONV_2D_DW:
8613 case GGML_OP_IM2COL:
8614 case GGML_OP_IM2COL_3D:
8615 case GGML_OP_SET_ROWS:
8616 case GGML_OP_SUM:
8617 case GGML_OP_SUM_ROWS:
8618 case GGML_OP_MEAN:
8619 return true;
8620 default:
8621 return false;
8622 }
8623}
8624
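// Returns the tensor's buffer offset modulo minStorageBufferOffsetAlignment. Ops that
// allow misalignment bind the descriptor at the aligned offset and pass this remainder
// to the shader through the push constants (see init_pushconst_tensor_offsets below).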
8625static uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t)
8626{
8627 return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));
8628}
8629
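// Generic case: ops without a push-constant specialization cannot compensate for a
// misaligned offset, so assert that every operand is already aligned.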
8630template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
8631 GGML_UNUSED(p);
8632 GGML_UNUSED(src0);
8633 GGML_UNUSED(src1);
8634 GGML_UNUSED(src2);
8635 GGML_UNUSED(src3);
8636 GGML_UNUSED(dst);
8637 static_assert(!std::is_const<T>::value, "unexpected type");
8638 GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
8639 GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
8640 GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
8641 GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
8642 GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
8643}
8644
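// The specializations below pack per-tensor element offsets (misalign bytes divided by
// the element size) into a single 32-bit push constant, e.g. (a_offset << 16) | d_offset
// for unary-style ops and (a_offset << 16) | (b_offset << 8) | d_offset for binary ops.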
8645template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
8646 const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type);
8647 const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type);
8648
8649 p.misalign_offsets = (a_offset << 16) | d_offset;
8650
8651 GGML_UNUSED(src1);
8652 GGML_UNUSED(src2);
8653 GGML_UNUSED(src3);
8654}
8655
8656template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
8657 const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type);
8658 const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type);
8659
8660 p.misalign_offsets = (a_offset << 16) | d_offset;
8661
8662 GGML_UNUSED(src1);
8663 GGML_UNUSED(src2);
8664 GGML_UNUSED(src3);
8665}
8666
8667template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
8668 const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type);
8669 const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type);
8670
8671 p.misalign_offsets = (a_offset << 16) | d_offset;
8672
8673 GGML_UNUSED(src1);
8674 GGML_UNUSED(src2);
8675 GGML_UNUSED(src3);
8676}
8677
8678template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
8679 const uint32_t a_offset = get_misalign_bytes(ctx, t: src1) / ggml_type_size(type: src1->type);
8680 const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type);
8681
8682 p.misalign_offsets = (a_offset << 16) | d_offset;
8683
8684 GGML_UNUSED(src0);
8685 GGML_UNUSED(src2);
8686 GGML_UNUSED(src3);
8687}
8688
8689template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
8690 const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type);
8691 const uint32_t b_offset = get_misalign_bytes(ctx, t: src1) / ggml_type_size(type: src1->type);
8692 const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type);
8693
8694 GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
8695
8696 p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
8697
8698 GGML_UNUSED(src2);
8699 GGML_UNUSED(src3);
8700}
8701
8702template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
8703 const uint32_t a_offset = get_misalign_bytes(ctx, t: src0) / ggml_type_size(type: src0->type);
8704 const uint32_t d_offset = get_misalign_bytes(ctx, t: dst) / ggml_type_size(type: dst->type);
8705
8706 p.a_offset = a_offset;
8707 p.d_offset = d_offset;
8708
8709 GGML_UNUSED(src1);
8710 GGML_UNUSED(src2);
8711 GGML_UNUSED(src3);
8712}
8713
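// Generic dispatch path for element-wise and reduction-style ops: selects the pipeline
// via ggml_vk_op_get_pipeline, computes the workgroup grid for the op, fills in the
// misalignment offsets, and binds the descriptor layout each shader family expects.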
8714template<typename PC>
8715static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
8716 VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
8717 if (src1 != nullptr) {
8718 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
8719 }
8720 if (src2 != nullptr) {
8721 std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
8722 }
8723 if (src3 != nullptr) {
8724 std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3];
8725 }
8726 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
8727 std::cerr << "), " << ggml_op_name(op) << ")");
8728 GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
8729 GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
8730 GGML_ASSERT(dst->buffer != nullptr);
8731 const uint64_t ne00 = src0->ne[0];
8732 const uint64_t ne01 = src0->ne[1];
8733 const uint64_t ne02 = src0->ne[2];
8734 const uint64_t ne03 = src0->ne[3];
8735
8736 const bool use_src1 = src1 != nullptr;
8737 const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
8738 const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
8739 const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
8740 const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
8741
8742 const bool use_src2 = src2 != nullptr;
8743 const bool use_src3 = src3 != nullptr;
8744
8745 init_pushconst_fastdiv(pc);
8746
8747 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
8748
8749 if (pipeline == nullptr) {
8750 std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(type: src0->type);
8751 if (src1 != nullptr) {
8752 std::cerr << " and " << ggml_type_name(type: src1->type);
8753 }
8754 std::cerr << " to " << ggml_type_name(type: dst->type) << std::endl;
8755 GGML_ABORT("fatal error");
8756 }
8757
8758 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
8759
8760 const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
8761
8762 vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, tensor: src0, allow_misalign: op_supports_incontiguous);
8763 vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, tensor: src1, allow_misalign: op_supports_incontiguous) : vk_subbuffer{};
8764 vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, tensor: src2, allow_misalign: op_supports_incontiguous) : vk_subbuffer{};
8765 vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, tensor: src3, allow_misalign: op_supports_incontiguous) : vk_subbuffer{};
8766 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, tensor: dst, allow_misalign: op_supports_incontiguous);
8767
8768 // Compute the misalignment offsets for the descriptors and store them in the push constants.
8769 init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
8770
8771 std::array<uint32_t, 3> elements;
8772
8773 // Single call if dimension 2 is contiguous
8774 GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
8775
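// Pick the global dispatch size for each op. Flat element counts larger than 512 are
// folded into a 512 x N (x Z) grid, and some dimensions are clamped to
// maxComputeWorkGroupCount, so the dispatch stays within device limits.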
8776 switch (op) {
8777 case GGML_OP_NORM:
8778 case GGML_OP_RMS_NORM_BACK:
8779 case GGML_OP_L2_NORM:
8780 case GGML_OP_SOFT_MAX:
8781 case GGML_OP_SOFT_MAX_BACK:
8782 case GGML_OP_SUM_ROWS:
8783 case GGML_OP_MEAN:
8784 case GGML_OP_ARGMAX:
8785 {
8786 const uint32_t nr = ggml_nrows(tensor: src0);
8787 if (nr > 262144) {
8788 elements = { 512, 512, CEIL_DIV(nr, 262144) };
8789 } else if (nr > 512) {
8790 elements = { 512, CEIL_DIV(nr, 512), 1 };
8791 } else {
8792 elements = { nr, 1, 1 };
8793 }
8794 } break;
8795 case GGML_OP_RMS_NORM:
8796 if (ctx->do_add_rms_partials) {
8797 // Run one element per thread, 128 threads per workgroup
8798 elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
8799 } else {
8800 elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
8801 }
8802 break;
8803
8804 case GGML_OP_SUM:
8805 // We use GGML_OP_SUM_ROWS with 1 row.
8806 elements = { 1, 1, 1 };
8807 break;
8808 case GGML_OP_GROUP_NORM:
8809 {
8810 const uint32_t num_groups = dst->op_params[0];
8811 elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
8812 } break;
8813 case GGML_OP_DIAG_MASK_INF:
8814 case GGML_OP_ROPE:
8815 case GGML_OP_ROPE_BACK:
8816 elements = { (uint32_t)ggml_nrows(tensor: src0), (uint32_t)ne00, 1 };
8817 break;
8818 case GGML_OP_GET_ROWS:
8819 elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
8820 elements[1] = std::min(a: elements[1], b: ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
8821 elements[2] = std::min(a: elements[2], b: ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
8822 break;
8823 case GGML_OP_ARGSORT:
8824 elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(tensor: src0), 1 };
8825 elements[1] = std::min(a: elements[1], b: ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
8826 break;
8827 case GGML_OP_IM2COL:
8828 {
8829 const bool is_2D = dst->op_params[6] == 1;
8830
8831 const uint32_t IC = src1->ne[is_2D ? 2 : 1];
8832
8833 const uint32_t KH = is_2D ? src0->ne[1] : 1;
8834 const uint32_t KW = src0->ne[0];
8835
8836 const uint32_t OH = is_2D ? dst->ne[2] : 1;
8837 const uint32_t OW = dst->ne[1];
8838
8839 const uint32_t batch = src1->ne[is_2D ? 3 : 2];
8840
8841 elements = { OW * KW * KH, OH, batch * IC };
8842 } break;
8843 case GGML_OP_IM2COL_3D:
8844 {
8845 const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
8846
8847 const uint32_t N = ne13 / IC;
8848
8849 const uint32_t KD = ne02;
8850 const uint32_t KH = ne01;
8851 const uint32_t KW = ne00;
8852
8853 const uint32_t OD = dst->ne[3] / N;
8854 const uint32_t OH = dst->ne[2];
8855 const uint32_t OW = dst->ne[1];
8856
8857 const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
8858 const uint32_t N_OD_OH = N*OD*OH;
8859
8860 elements = { IC_KD_KH_KW, OW, N_OD_OH };
8861 elements[2] = std::min(a: elements[2], b: ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
8862 } break;
8863 case GGML_OP_TIMESTEP_EMBEDDING:
8864 {
8865 const uint32_t dim = dst->op_params[0];
8866 uint32_t half_ceil = (dim + 1) / 2;
8867 elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
8868 } break;
8869 case GGML_OP_CONV_TRANSPOSE_1D:
8870 {
8871 elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
8872 } break;
8873 case GGML_OP_POOL_2D:
8874 {
8875 const uint32_t N = dst->ne[3];
8876 const uint32_t OC = dst->ne[2];
8877 const uint32_t OH = dst->ne[1];
8878 const uint32_t OW = dst->ne[0];
8879 elements = { N * OC * OH * OW, 1, 1};
8880 } break;
8881 case GGML_OP_CONV_2D:
8882 {
8883 elements = ggml_vk_get_conv_elements(dst);
8884 } break;
8885 case GGML_OP_CONV_TRANSPOSE_2D:
8886 {
8887 elements = ggml_vk_get_conv_transpose_2d_elements(dst);
8888 } break;
8889 case GGML_OP_ADD:
8890 case GGML_OP_SUB:
8891 case GGML_OP_DIV:
8892 case GGML_OP_MUL:
8893 case GGML_OP_SCALE:
8894 case GGML_OP_SQR:
8895 case GGML_OP_SQRT:
8896 case GGML_OP_SIN:
8897 case GGML_OP_COS:
8898 case GGML_OP_CLAMP:
8899 case GGML_OP_PAD:
8900 case GGML_OP_ROLL:
8901 case GGML_OP_REPEAT:
8902 case GGML_OP_REPEAT_BACK:
8903 case GGML_OP_CPY:
8904 case GGML_OP_CONCAT:
8905 case GGML_OP_UPSCALE:
8906 case GGML_OP_UNARY:
8907 case GGML_OP_GLU:
8908 case GGML_OP_CONV_2D_DW:
8909 {
8910 uint32_t ne = ggml_nelements(tensor: dst);
8911 if (op == GGML_OP_CPY && ggml_is_quantized(type: src0->type) && ggml_is_quantized(type: dst->type)) {
8912 // Convert from number of logical elements to 2- or 4-byte units.
8913 ne /= ggml_blck_size(type: src0->type);
8914 if ((ggml_type_size(type: src0->type) % 4) == 0) {
8915 ne *= ggml_type_size(type: src0->type) / 4;
8916 } else {
8917 ne *= ggml_type_size(type: src0->type) / 2;
8918 }
8919 }
8920 // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
8921 // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
8922 // So divide by block size here before splitting into 512x512 groups.
8923 if (op == GGML_OP_CPY && !ggml_is_quantized(type: src0->type) && ggml_is_quantized(type: dst->type)) {
8924 ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
8925 }
8926 if (ne > 262144) {
8927 elements = { 512, 512, CEIL_DIV(ne, 262144) };
8928 } else if (ne > 512) {
8929 elements = { 512, CEIL_DIV(ne, 512), 1 };
8930 } else {
8931 elements = { ne, 1, 1 };
8932 }
8933 } break;
8934 case GGML_OP_ADD_ID:
8935 {
8936 elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
8937 } break;
8938 case GGML_OP_SET_ROWS:
8939 {
8940 uint32_t ne = ggml_nelements(tensor: src0);
8941 if (ggml_is_quantized(type: dst->type)) {
8942 // quants run 32 threads each doing QUANT_K elements
8943 ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
8944 } else {
8945 // scalar types do one element per thread, running 512 threads
8946 ne = CEIL_DIV(ne, 512);
8947 }
8948 if (ne > 262144) {
8949 elements = { 512, 512, CEIL_DIV(ne, 262144) };
8950 } else if (ne > 512) {
8951 elements = { 512, CEIL_DIV(ne, 512), 1 };
8952 } else {
8953 elements = { ne, 1, 1 };
8954 }
8955 }
8956 break;
8957 case GGML_OP_SSM_CONV:
8958 {
8959 const uint32_t nr = src0->ne[1];
8960 const uint32_t n_t = dst->ne[1];
8961 const uint32_t n_s = dst->ne[2];
8962 elements = { nr, n_t, n_s };
8963 }
8964 break;
8965 default:
8966 elements = { (uint32_t)ggml_nelements(tensor: src0), 1, 1 };
8967 break;
8968 }
8969
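// Dispatch. Each branch matches the descriptor layout of its shader family; optional
// inputs that are absent are substituted with src0 because every declared binding
// still needs a valid buffer.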
8970 if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
8971 vk_subbuffer a_buf = src0_buf;
8972 if (ctx->do_add_rms_partials) {
8973 a_buf = ggml_vk_subbuffer(ctx, buf: ctx->prealloc_add_rms_partials, offset: ctx->prealloc_size_add_rms_partials_offset);
8974 }
8975 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8976 { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
8977 } else if (op == GGML_OP_GLU) {
8978 // Empty src1 is possible in glu, but the shader needs a buffer
8979 vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
8980 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
8981 } else if (op == GGML_OP_SOFT_MAX) {
8982 // Empty src1 and src2 are possible in soft_max, but the shader needs a buffer
8983 vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
8984 vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
8985 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
8986 } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
8987 // Empty src2 and src3 are possible in rope, but the shader needs a buffer
8988 vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
8989 vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf;
8990 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);
8991 } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
8992 if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
8993 // buffer device address path doesn't use dst buffer
8994 dst_buf.size = 1;
8995 }
8996 // im2col uses only src1 and dst buffers
8997 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
8998 } else if (op == GGML_OP_COUNT_EQUAL) {
8999 // count_equal assumes that the destination buffer is initialized with zeroes
9000 ggml_vk_buffer_memset_async(ctx&: subctx, dst&: dst_buf.buffer, offset: dst_buf.offset, c: 0, size: dst_buf.size);
9001 ggml_vk_sync_buffers(ctx, subctx);
9002 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
9003 } else if (op == GGML_OP_OPT_STEP_SGD) {
9004 // OPT_STEP_SGD works on src0 in place; it does not need dst
9005 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
9006 } else if (use_src3) {
9007 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);
9008 } else if (use_src2) {
9009 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
9010 } else if (use_src1) {
9011 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
9012 } else {
9013 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
9014 }
9015}
9016
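// The wrappers below translate a single ggml op into a vk_op_*_push_constants struct
// (element counts plus strides expressed in elements, i.e. nb[] divided by the type
// size) and forward it to ggml_vk_op_f32.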
9017static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9018 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9019 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9020 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9021
9022 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_GET_ROWS, pc: {
9023 .ne: (uint32_t)ggml_nelements(tensor: src0),
9024 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9025 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9026 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9027 .misalign_offsets: 0,
9028 .param1: 0.0f, .param2: 0.0f, .param3: 0,
9029 });
9030}
9031
9032static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9033 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9034 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9035 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9036
9037 int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
9038 int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
9039 // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
9040 int offset = dst->op_params[3] / 4; // op_params[3] is a byte offset, converted here to float elements
9041
9042 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ACC, pc: {
9043 .ne: (uint32_t)ggml_nelements(tensor: src0),
9044 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)nb1, .nb02: (uint32_t)nb2, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9045 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9046 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t)nb1, .nb22: (uint32_t)nb2, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9047 .misalign_offsets: 0,
9048 .param1: 0.0f, .param2: 0.0f, .param3: offset,
9049 });
9050}
9051
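// Fused chain of GGML_OP_ADD nodes: collect every distinct source tensor of the chain
// plus the final destination into one descriptor set (bounded by MAX_PARAMETER_COUNT),
// and optionally append the partial-sums buffer used by a following fused RMS_NORM.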
9052static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
9053 const ggml_tensor *first_node = cgraph->nodes[node_idx];
9054 const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
9055
9056 // Make a list of all the tensors used by the op.
9057 // The last element of the list is the destination tensor.
9058 const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
9059 uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
9060 uint32_t num_tensors = num_srcs + 1;
9061 GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
9062
9063 tensors[0] = first_node->src[0];
9064 tensors[1] = first_node->src[1];
9065 for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
9066 // check whether the previous result is src[0] or src[1]
9067 if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
9068 tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
9069 } else {
9070 tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
9071 }
9072 }
9073 tensors[num_srcs] = dst;
9074
9075 vk_op_multi_add_push_constants pc;
9076 pc.ne20 = (uint32_t)dst->ne[0];
9077 pc.ne21 = (uint32_t)dst->ne[1];
9078 pc.ne22 = (uint32_t)dst->ne[2];
9079 pc.ne23 = (uint32_t)dst->ne[3];
9080
9081 for (uint32_t i = 0; i < num_tensors; ++i) {
9082 const ggml_tensor *t = tensors[i];
9083 pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
9084 pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
9085 pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
9086 pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
9087 }
9088 pc.rms_partials = ctx->do_add_rms_partials;
9089
9090 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: tensors[0], src1: tensors[1], src2: nullptr, dst, op: dst->op);
9091
9092 if (pipeline == nullptr) {
9093 std::cerr << "ggml_vulkan: Error: Missing multi_add";
9094 GGML_ABORT("fatal error");
9095 }
9096
9097 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
9098
9099 ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
9100 vk_buffer buf[MAX_PARAMETER_COUNT];
9101 size_t offset[MAX_PARAMETER_COUNT];
9102 bool uma[MAX_PARAMETER_COUNT];
9103
9104 for (uint32_t i = 0; i < num_tensors; ++i) {
9105 buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
9106 buf[i] = nullptr;
9107 offset[i] = 0;
9108 uma[i] = false;
9109
9110 if (ctx->device->uma) {
9111 ggml_vk_host_get(device: ctx->device, ptr: tensors[i]->data, buf&: buf[i], buf_offset&: offset[i]);
9112 uma[i] = buf[i] != nullptr;
9113 }
9114 if (!uma[i]) {
9115 buf[i] = buf_ctx[i]->dev_buffer;
9116 offset[i] = vk_tensor_offset(tensor: tensors[i]) + tensors[i]->view_offs;
9117 }
9118 GGML_ASSERT(buf[i] != nullptr);
9119 }
9120 // If any remaining descriptors are unused, just point them at src[0]
9121 for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
9122 buf[i] = buf[0];
9123 offset[i] = 0;
9124 }
9125 if (ctx->do_add_rms_partials) {
9126 buf[num_tensors] = ctx->prealloc_add_rms_partials;
9127 offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
9128 }
9129
9130 std::array<uint32_t, 3> elements;
9131
9132 uint32_t ne = ggml_nelements(tensor: dst);
9133 if (ne > 262144) {
9134 elements = { 512, 512, CEIL_DIV(ne, 262144) };
9135 } else if (ne > 512) {
9136 elements = { 512, CEIL_DIV(ne, 512), 1 };
9137 } else {
9138 elements = { ne, 1, 1 };
9139 }
9140
9141 static_assert(MAX_PARAMETER_COUNT == 12);
9142 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9143 descriptor_buffer_infos: {
9144 ggml_vk_subbuffer(ctx, buf: buf[0], offset: offset[0]),
9145 ggml_vk_subbuffer(ctx, buf: buf[1], offset: offset[1]),
9146 ggml_vk_subbuffer(ctx, buf: buf[2], offset: offset[2]),
9147 ggml_vk_subbuffer(ctx, buf: buf[3], offset: offset[3]),
9148 ggml_vk_subbuffer(ctx, buf: buf[4], offset: offset[4]),
9149 ggml_vk_subbuffer(ctx, buf: buf[5], offset: offset[5]),
9150 ggml_vk_subbuffer(ctx, buf: buf[6], offset: offset[6]),
9151 ggml_vk_subbuffer(ctx, buf: buf[7], offset: offset[7]),
9152 ggml_vk_subbuffer(ctx, buf: buf[8], offset: offset[8]),
9153 ggml_vk_subbuffer(ctx, buf: buf[9], offset: offset[9]),
9154 ggml_vk_subbuffer(ctx, buf: buf[10], offset: offset[10]),
9155 ggml_vk_subbuffer(ctx, buf: buf[11], offset: offset[11]),
9156 }, push_constants: pc, elements);
9157}
9158
9159static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9160 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9161 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9162 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9163
9164 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ADD, pc: {
9165 .ne: (uint32_t)ggml_nelements(tensor: src0),
9166 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9167 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9168 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9169 .misalign_offsets: 0,
9170 .param1: 0.0f, .param2: 0.0f, .param3: ctx->do_add_rms_partials,
9171 });
9172}
9173
9174static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9175 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9176 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9177 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9178
9179 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SUB, pc: {
9180 .ne: (uint32_t)ggml_nelements(tensor: src0),
9181 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9182 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9183 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9184 .misalign_offsets: 0,
9185 .param1: 0.0f, .param2: 0.0f, .param3: 0,
9186 });
9187}
9188
9189static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9190 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9191 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9192 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9193
9194 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_MUL, pc: {
9195 .ne: (uint32_t)ggml_nelements(tensor: src0),
9196 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9197 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9198 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9199 .misalign_offsets: 0,
9200 .param1: 0.0f, .param2: 0.0f, .param3: 0,
9201 });
9202}
9203
9204static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9205 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9206 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9207 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9208
9209 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_DIV, pc: {
9210 .ne: (uint32_t)ggml_nelements(tensor: src0),
9211 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9212 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9213 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9214 .misalign_offsets: 0,
9215 .param1: 0.0f, .param2: 0.0f, .param3: 0,
9216 });
9217}
9218
9219static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
9220 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9221 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9222 const uint32_t src2_type_size = ggml_type_size(type: src2->type);
9223
9224 ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, src3: nullptr, dst, op: GGML_OP_ADD_ID, pc: {
9225 .ne0: (uint32_t)dst->ne[0],
9226 .ne1: (uint32_t)dst->ne[1],
9227 .s01: (uint32_t)src0->nb[1] / src0_type_size,
9228 .s02: (uint32_t)src0->nb[2] / src0_type_size,
9229 .s11: (uint32_t)src1->nb[1] / src1_type_size,
9230 .s21: (uint32_t)src2->nb[1] / src2_type_size,
9231 });
9232}
9233
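// Shared dispatch for the RWKV wkv6/wkv7 ops: one workgroup per (sequence, head) pair,
// with the sources bound in order followed by the destination.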
9234static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {
9235 GGML_ASSERT(version == 6 || version == 7);
9236 int num_srcs = version == 6 ? 6 : 7;
9237
9238 for (int i = 0; i < num_srcs; i++) {
9239 GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
9240 }
9241
9242 GGML_ASSERT(dst->buffer != nullptr);
9243
9244 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: dst->src[0], src1: dst->src[1], src2: dst->src[2], dst, op: dst->op);
9245 GGML_ASSERT(pipeline != nullptr);
9246
9247 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
9248
9249 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, tensor: dst);
9250 vk_subbuffer src_buf[7] = {};
9251 for (int i = 0; i < num_srcs; i++) {
9252 src_buf[i] = ggml_vk_tensor_subbuffer(ctx, tensor: dst->src[i]);
9253 }
9254
9255 std::array<uint32_t, 3> elements = {
9256 (uint32_t)(pc.B * pc.H),
9257 1,
9258 1
9259 };
9260
9261 if (version == 6) {
9262 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9263 descriptor_buffer_infos: {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
9264 push_constants: pc, elements);
9265 } else if (version == 7) {
9266 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9267 descriptor_buffer_infos: {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
9268 push_constants: pc, elements);
9269 } else {
9270 // shouldn't happen
9271 GGML_ASSERT(false);
9272 }
9273}
9274
9275static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
9276 const size_t seq_length = dst->src[0]->ne[2];
9277 const size_t n_embed = dst->ne[0];
9278 const size_t n_heads = dst->src[0]->ne[1];
9279 const size_t n_seqs = dst->src[5]->ne[1];
9280
9281 ggml_vk_op_f32_wkv(
9282 ctx, subctx, dst,
9283 pc: {
9284 .B: (uint32_t)n_seqs,
9285 .T: (uint32_t)seq_length,
9286 .C: (uint32_t)n_embed,
9287 .H: (uint32_t)n_heads,
9288 },
9289 version: 6
9290 );
9291}
9292
9293static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
9294 const size_t seq_length = dst->src[0]->ne[2];
9295 const size_t n_embed = dst->ne[0];
9296 const size_t n_heads = dst->src[0]->ne[1];
9297 const size_t n_seqs = dst->src[6]->ne[1];
9298
9299 ggml_vk_op_f32_wkv(
9300 ctx, subctx, dst,
9301 pc: {
9302 .B: (uint32_t)n_seqs,
9303 .T: (uint32_t)seq_length,
9304 .C: (uint32_t)n_embed,
9305 .H: (uint32_t)n_heads,
9306 },
9307 version: 7
9308 );
9309}
9310
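// SSM scan (Mamba-2 layout only, see the is_mamba2 assert). The d_state-specific
// pipeline was already chosen in ggml_vk_op_get_pipeline; here we dispatch
// CEIL_DIV(n_head * head_dim, 16) x n_seq workgroups and pass the strides and the
// state offset (s_off) through the push constants.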
9311static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
9312 const ggml_tensor * src0 = dst->src[0];
9313 const ggml_tensor * src1 = dst->src[1];
9314 const ggml_tensor * src2 = dst->src[2];
9315 const ggml_tensor * src3 = dst->src[3];
9316 const ggml_tensor * src4 = dst->src[4];
9317 const ggml_tensor * src5 = dst->src[5];
9318
9319 GGML_ASSERT(dst->buffer != nullptr);
9320
9321 const uint32_t head_dim = src0->ne[1];
9322 const uint32_t n_head = src1->ne[1];
9323 const uint32_t n_group = src4->ne[1];
9324 const uint32_t n_tok = src1->ne[2];
9325 const uint32_t n_seq = src1->ne[3];
9326
9327 bool is_mamba2 = (src3->nb[1] == sizeof(float));
9328 GGML_ASSERT(is_mamba2);
9329
9330 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op: dst->op);
9331 GGML_ASSERT(pipeline != nullptr);
9332
9333 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
9334
9335 const int64_t s_off = ggml_nelements(tensor: src1) * sizeof(float);
9336
9337 const vk_op_ssm_scan_push_constants pc = {
9338 .nb02: (uint32_t)src0->nb[2], .nb03: (uint32_t)src0->nb[3],
9339 .nb12: (uint32_t)src1->nb[2], .nb13: (uint32_t)src1->nb[3],
9340 .nb21: (uint32_t)src2->nb[1], .nb22: (uint32_t)src2->nb[2],
9341 .nb31: (uint32_t)src3->nb[1],
9342 .nb42: (uint32_t)src4->nb[2], .nb43: (uint32_t)src4->nb[3],
9343 .nb52: (uint32_t)src5->nb[2], .nb53: (uint32_t)src5->nb[3],
9344 .s_off: (uint32_t)s_off,
9345 .n_head: n_head, .d_head: head_dim, .n_group: n_group, .n_tok: n_tok
9346 };
9347
9348 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, tensor: dst);
9349 vk_subbuffer src_buf[7] = {};
9350 for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
9351 src_buf[i] = ggml_vk_tensor_subbuffer(ctx, tensor: dst->src[i]);
9352 }
9353
9354 std::array<uint32_t, 3> elements;
9355
9356 const int splitH = 16;
9357 const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
9358 const uint32_t num_workgroups_y = n_seq;
9359 elements = { num_workgroups_x, num_workgroups_y, 1 };
9360
9361 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9362 descriptor_buffer_infos: {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
9363 push_constants: pc, elements);
9364}
9365
9366static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
9367 const ggml_tensor * src0 = dst->src[0];
9368 const ggml_tensor * src1 = dst->src[1];
9369
9370 ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SSM_CONV, pc: {
9371 .nb01: (uint32_t)src0->nb[1], .nb02: (uint32_t)src0->nb[2],
9372 .nb11: (uint32_t)src1->nb[1],
9373 .dst_nb0: (uint32_t)dst->nb[0], .dst_nb1: (uint32_t)dst->nb[1], .dst_nb2: (uint32_t)dst->nb[2],
9374 .nc: (uint32_t)src1->ne[0],
9375 .ncs: (uint32_t)src0->ne[0],
9376 .nr: (uint32_t)src0->ne[1],
9377 .n_t: (uint32_t)dst->ne[1],
9378 .n_s: (uint32_t)dst->ne[2],
9379 });
9380}
9381
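// AdamW optimizer step: validates that all five operands (parameters, gradients, the
// two gradient moments, and the 7-element tensor p of optimizer constants) are
// contiguous F32 tensors, then runs one thread per parameter element.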
9382static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc) {
9383 const ggml_tensor * x = dst->src[0];
9384 const ggml_tensor * g = dst->src[1];
9385 const ggml_tensor * gm = dst->src[2];
9386 const ggml_tensor * gv = dst->src[3];
9387 const ggml_tensor * p = dst->src[4];
9388
9389 GGML_ASSERT(x->type == GGML_TYPE_F32);
9390 GGML_ASSERT(g->type == GGML_TYPE_F32);
9391 GGML_ASSERT(gm->type == GGML_TYPE_F32);
9392 GGML_ASSERT(gv->type == GGML_TYPE_F32);
9393 GGML_ASSERT(p->type == GGML_TYPE_F32);
9394 GGML_ASSERT(dst->buffer != nullptr);
9395 GGML_ASSERT(ggml_is_contiguous(x));
9396 GGML_ASSERT(ggml_is_contiguous(g));
9397 GGML_ASSERT(ggml_is_contiguous(gm));
9398 GGML_ASSERT(ggml_is_contiguous(gv));
9399 GGML_ASSERT(ggml_is_contiguous(p));
9400 GGML_ASSERT(ggml_are_same_shape(x, g));
9401 GGML_ASSERT(ggml_are_same_shape(x, gm));
9402 GGML_ASSERT(ggml_are_same_shape(x, gv));
9403 GGML_ASSERT(ggml_nelements(p) == 7);
9404
9405 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: g, src1: gm, src2: gv, dst, op: GGML_OP_OPT_STEP_ADAMW);
9406 GGML_ASSERT(pipeline != nullptr);
9407
9408 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
9409
9410 vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, tensor: x);
9411 vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, tensor: g);
9412 vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, tensor: gm);
9413 vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, tensor: gv);
9414 vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, tensor: p);
9415
9416 std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(tensor: x), 1, 1 };
9417
9418 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9419 descriptor_buffer_infos: {x_buf, g_buf, gm_buf, gv_buf, p_buf},
9420 push_constants: pc, elements);
9421}
9422
9423static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
9424 const size_t n = ggml_nelements(tensor: dst->src[0]);
9425
9426 ggml_vk_op_f32_opt_step_adamw(
9427 ctx, subctx, dst,
9428 pc: { .KX: (uint32_t)n, .KY: 0, .param1: 0.0f, .param2: 0.0f }
9429 );
9430}
9431
9432static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
9433 const size_t n = ggml_nelements(tensor: dst->src[0]);
9434
9435 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, src3: nullptr, dst, op: GGML_OP_OPT_STEP_SGD, pc: { .KX: (uint32_t)n, .KY: 0, .param1: 0.0f, .param2: 0.0f });
9436}
9437
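// Concat: the concat dimension from op_params[0] is forwarded to the shader in the
// param3 slot of the binary push constants.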
9438static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9439 int * op_params = (int *)dst->op_params;
9440
9441 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9442 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9443 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9444
9445 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONCAT, pc: {
9446 .ne: (uint32_t)ggml_nelements(tensor: dst),
9447 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9448 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9449 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9450 .misalign_offsets: 0,
9451 .param1: 0.0f, .param2: 0.0f, .param3: op_params[0],
9452 });
9453}
9454
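// Upscale/interpolate: derive per-axis scale factors from the size ratios; with
// GGML_SCALE_FLAG_ALIGN_CORNERS the x/y factors are recomputed from (n - 1) ratios and
// the half-pixel offset is dropped.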
9455static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9456 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9457 const uint32_t mode = (uint32_t)ggml_get_op_params_i32(tensor: dst, i: 0);
9458
9459 GGML_TENSOR_UNARY_OP_LOCALS
9460
9461 float sf0 = (float)ne0 / ne00;
9462 float sf1 = (float)ne1 / ne01;
9463 float sf2 = (float)ne2 / ne02;
9464 float sf3 = (float)ne3 / ne03;
9465 float pixel_offset = 0.5f;
9466
9467 if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
9468 sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
9469 sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
9470 pixel_offset = 0.0f;
9471 }
9472
9473 ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_UPSCALE, pc: {
9474 .ne: (uint32_t)ggml_nelements(tensor: dst), .a_offset: 0, .d_offset: 0,
9475 .ne00: (uint32_t)ne00, .ne01: (uint32_t)ne01,
9476 .nb00: (uint32_t)nb00 / src0_type_size, .nb01: (uint32_t)nb01 / src0_type_size, .nb02: (uint32_t)nb02 / src0_type_size, .nb03: (uint32_t)nb03 / src0_type_size,
9477 .ne10: (uint32_t)ne0, .ne11: (uint32_t)ne1, .ne12: (uint32_t)ne2, .ne13: (uint32_t)ne3,
9478 .sf0: sf0, .sf1: sf1, .sf2: sf2, .sf3: sf3, .pixel_offset: pixel_offset
9479 });
9480}
9481
9482static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9483 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
9484 p.param1 = ggml_get_op_params_f32(tensor: dst, i: 0);
9485 p.param2 = ggml_get_op_params_f32(tensor: dst, i: 1);
9486
9487 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SCALE, pc: std::move(p));
9488}
9489
9490static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9491 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SQR, pc: vk_op_unary_push_constants_init(src0, dst));
9492}
9493
9494static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9495 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SQRT, pc: vk_op_unary_push_constants_init(src0, dst));
9496}
9497
9498static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9499 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SIN, pc: vk_op_unary_push_constants_init(src0, dst));
9500}
9501
9502static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9503 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_COS, pc: vk_op_unary_push_constants_init(src0, dst));
9504}
9505
9506static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9507 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
9508 p.param1 = ggml_get_op_params_f32(tensor: dst, i: 0);
9509 p.param2 = ggml_get_op_params_f32(tensor: dst, i: 1);
9510
9511 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CLAMP, pc: std::move(p));
9512}
9513
9514static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9515 vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
9516 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_PAD, pc: std::move(p));
9517}
9518
9519static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9520 const int32_t s0 = ggml_get_op_params_i32(tensor: dst, i: 0);
9521 const int32_t s1 = ggml_get_op_params_i32(tensor: dst, i: 1);
9522 const int32_t s2 = ggml_get_op_params_i32(tensor: dst, i: 2);
9523 const int32_t s3 = ggml_get_op_params_i32(tensor: dst, i: 3);
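    // Pack the four shift amounts two per uint32, biasing each by 0x8000 so it fits in an
    // unsigned 16-bit half; the packed words are bit-copied into the two float params below
    // (presumably recovered in the shader as (bits >> 16) - 0x8000 and (bits & 0xffff) - 0x8000).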
9524 const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
9525 const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
9526
9527 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
9528 memcpy(dest: &p.param1, src: &s01_packed, n: sizeof(float));
9529 memcpy(dest: &p.param2, src: &s23_packed, n: sizeof(float));
9530
9531 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ROLL, pc: std::move(p));
9532}
9533
9534static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9535 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne: ggml_nelements(tensor: dst));
9536 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_REPEAT, pc: std::move(p));
9537}
9538
9539static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9540 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne: ggml_nelements(tensor: dst));
9541 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_REPEAT_BACK, pc: std::move(p));
9542}
9543
9544static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9545 uint32_t ne = (uint32_t)ggml_nelements(tensor: src0);
9546 if (ggml_is_quantized(type: src0->type) && ggml_is_quantized(type: dst->type)) {
9547 // Convert from number of logical elements to 2- or 4-byte units.
9548 ne /= ggml_blck_size(type: src0->type);
9549 if ((ggml_type_size(type: src0->type) % 4) == 0) {
9550 ne *= ggml_type_size(type: src0->type) / 4;
9551 } else {
9552 ne *= ggml_type_size(type: src0->type) / 2;
9553 }
9554 }
9555
9556 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
9557 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CPY, pc: std::move(p));
9558}
9559
9560static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9561 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9562 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9563 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9564
9565 // Skip empty set_rows operations. For most ops the empty check at the start
9566 // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
9567 // with empty srcs.
9568 if (ggml_is_empty(tensor: src0) || ggml_is_empty(tensor: src1)) {
9569 return;
9570 }
9571
9572 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SET_ROWS, pc: {
9573 .ne: (uint32_t)ggml_nelements(tensor: src0),
9574 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9575 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9576 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9577 .misalign_offsets: 0,
9578 .param1: 0.0f, .param2: 0.0f, .param3: 0,
9579 });
9580}
9581
9582static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9583 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SILU_BACK, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: 0.0f, .param2: 0.0f });
9584}
9585
9586static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9587 float * op_params = (float *)dst->op_params;
9588
9589 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_NORM, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: op_params[0], .param2: 0.0f });
9590}
9591
9592static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9593 const int * int_op_params = (const int *)dst->op_params;
9594 const float * float_op_params = (const float *)dst->op_params;
9595
9596 const uint32_t num_groups = int_op_params[0];
9597 const float eps = float_op_params[1];
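    // Elements per group: one full ne0*ne1 plane for each of the ceil(ne2 / num_groups)
    // channels assigned to a group.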
9598 const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
9599
9600 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_GROUP_NORM, pc: { .KX: group_size, .KY: 0, .param1: eps, .param2: 0.0f });
9601}
9602
9603static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
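    // Number of workgroups the fused add+rms pipeline covers along ne[0], i.e. how many
    // partial sums it produces per row.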
9604 const uint32_t ne = (uint32_t)node->ne[0];
9605 const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
9606 const uint32_t num_partials = CEIL_DIV(ne, denom);
9607 return num_partials;
9608}
9609
9610static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
9611 const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
9612 const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
9613 return num_bytes;
9614}
9615
9616static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {
9617 const int n_dims = ((const int32_t *) dst->op_params)[1];
9618 const int mode = ((const int32_t *) dst->op_params)[2];
9619 // const int n_ctx = ((const int32_t *) dst->op_params)[3];
9620 const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
9621 const float freq_base = ((const float *) dst->op_params)[5];
9622 const float freq_scale = ((const float *) dst->op_params)[6];
9623 const float ext_factor = ((const float *) dst->op_params)[7];
9624 const float attn_factor = ((const float *) dst->op_params)[8];
9625 const float beta_fast = ((const float *) dst->op_params)[9];
9626 const float beta_slow = ((const float *) dst->op_params)[10];
9627 int sections[4] {};
9628 if (mode & GGML_ROPE_TYPE_MROPE) {
9629 memcpy(dest: sections, src: (const int32_t *) dst->op_params + 11, n: sizeof(int)*4);
9630 }
9631
9632 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
9633
9634 float corr_dims[2];
9635 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, dims: corr_dims);
9636
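    // Standard RoPE frequency decay: theta_i = theta_scale^i with theta_scale = freq_base^(-2/n_dims).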
9637 const float theta_scale = powf(x: freq_base, y: -2.0f/n_dims);
9638
9639 uint32_t nb01 = src0->nb[1] / ggml_type_size(type: src0->type);
9640 uint32_t nb02 = src0->nb[2] / ggml_type_size(type: src0->type);
9641
9642 vk_op_rope_push_constants rope {
9643 .rope_mode: (uint32_t)mode, .ncols: (uint32_t)src0->ne[0], .n_dims: (uint32_t)n_dims, .freq_scale: freq_scale, .p_delta_rows: (uint32_t)src0->ne[1],
9644 .freq_base: freq_base, .ext_factor: ext_factor, .attn_factor: attn_factor, .corr_dims: {corr_dims[0], corr_dims[1]}, .theta_scale: theta_scale,
9645 .has_ff: has_ff, .ne02: (uint32_t)src0->ne[2], .s1: nb01, .s2: nb02,
9646 .sections: { sections[0], sections[1], sections[2], sections[3] }, .is_imrope: is_imrope, .is_back: backprop, .set_rows_stride: set_rows_stride,
9647 };
9648
9649 return rope;
9650}
9651
9652static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) {
9653 ggml_tensor * dst;
9654 const ggml_tensor * src0;
9655 const ggml_tensor * src1;
9656
9657 if (ctx->num_additional_fused_ops > 0) {
9658 // fused rms_norm + mul
9659 ggml_tensor *mul = cgraph->nodes[node_idx + 1];
9660 ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0];
9661 dst = mul;
9662 src0 = cgraph->nodes[node_idx]->src[0];
9663 src1 = other_src;
9664 } else {
9665 dst = cgraph->nodes[node_idx];
9666 src0 = src1 = dst->src[0];
9667 }
9668
9669 const uint32_t src0_type_size = ggml_type_size(type: src0->type);
9670 const uint32_t src1_type_size = ggml_type_size(type: src1->type);
9671 const uint32_t dst_type_size = ggml_type_size(type: dst->type);
9672
9673 uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, node: dst) : 0;
9674
9675 vk_op_binary_push_constants bin {
9676 .ne: (uint32_t)ggml_nelements(tensor: src0),
9677 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],.ne03: (uint32_t)src0->ne[3], .nb00: (uint32_t)src0->nb[0] / src0_type_size, .nb01: (uint32_t)src0->nb[1] / src0_type_size, .nb02: (uint32_t)src0->nb[2] / src0_type_size, .nb03: (uint32_t)src0->nb[3] / src0_type_size,
9678 .ne10: (uint32_t)src1->ne[0], .ne11: (uint32_t)src1->ne[1], .ne12: (uint32_t)src1->ne[2],.ne13: (uint32_t)src1->ne[3], .nb10: (uint32_t)src1->nb[0] / src1_type_size, .nb11: (uint32_t)src1->nb[1] / src1_type_size, .nb12: (uint32_t)src1->nb[2] / src1_type_size, .nb13: (uint32_t)src1->nb[3] / src1_type_size,
9679 .ne20: (uint32_t) dst->ne[0], .ne21: (uint32_t) dst->ne[1], .ne22: (uint32_t) dst->ne[2],.ne23: (uint32_t) dst->ne[3], .nb20: (uint32_t) dst->nb[0] / dst_type_size, .nb21: (uint32_t) dst->nb[1] / dst_type_size, .nb22: (uint32_t) dst->nb[2] / dst_type_size, .nb23: (uint32_t) dst->nb[3] / dst_type_size,
9680 .misalign_offsets: 0,
9681 .param1: op_params[0], .param2: 0.0f, .param3: (int32_t)param3,
9682 };
9683
9684 // more than one fused op means rms_norm+mul+rope
9685 if (ctx->num_additional_fused_ops > 1) {
9686 static constexpr uint32_t max_tensors = 7;
9687 const ggml_tensor *tensors[max_tensors] {};
9688
9689 ggml_tensor *rms = cgraph->nodes[node_idx + 0];
9690 ggml_tensor *mul = cgraph->nodes[node_idx + 1];
9691 ggml_tensor *rope = cgraph->nodes[node_idx + 2];
9692
9693 ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
9694
9695 bool do_set_rows = ctx->num_additional_fused_ops == 4;
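        // Four additional fused ops means the chain ends in set_rows (likely rms_norm + mul +
        // rope + view + set_rows); in that case the final node's src[1] supplies the row indices
        // and its row stride is forwarded to the rope shader as set_rows_stride.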
9696
9697 tensors[0] = rms->src[0];
9698 tensors[1] = other_src;
9699 tensors[2] = mul;
9700 tensors[3] = rope->src[1]; // pos
9701 tensors[4] = rope->src[2]; // ff
9702 tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst
9703 tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr;
9704 const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(type: tensors[5]->type) : 0;
9705
9706 vk_op_rms_norm_mul_rope_push_constants pc;
9707 pc.bin = bin;
9708 pc.rope = ggml_vk_make_rope_constants(dst: rope, src0: rope->src[0], has_ff: tensors[4] != nullptr, backprop: false, set_rows_stride);
9709
9710 vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
9711
9712 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
9713
9714 ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
9715 vk_buffer buf[max_tensors];
9716 size_t offset[max_tensors];
9717 bool uma[max_tensors];
9718
9719 for (uint32_t i = 0; i < max_tensors; ++i) {
9720 if (!tensors[i]) {
9721 // If any remaining descriptors are unused, just point them at src[0]
9722 buf[i] = buf[0];
9723 offset[i] = 0;
9724 continue;
9725 }
9726 buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
9727 buf[i] = nullptr;
9728 offset[i] = 0;
9729 uma[i] = false;
9730
9731 if (ctx->device->uma) {
9732 ggml_vk_host_get(device: ctx->device, ptr: tensors[i]->data, buf&: buf[i], buf_offset&: offset[i]);
9733 uma[i] = buf[i] != nullptr;
9734 }
9735 if (!uma[i]) {
9736 buf[i] = buf_ctx[i]->dev_buffer;
9737 offset[i] = vk_tensor_offset(tensor: tensors[i]) + tensors[i]->view_offs;
9738 }
9739 GGML_ASSERT(buf[i] != nullptr);
9740 }
9741
9742 std::array<uint32_t, 3> elements;
9743 elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
9744
9745 static_assert(max_tensors == 7);
9746 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9747 descriptor_buffer_infos: {
9748 ggml_vk_subbuffer(ctx, buf: buf[0], offset: offset[0]),
9749 ggml_vk_subbuffer(ctx, buf: buf[1], offset: offset[1]),
9750 ggml_vk_subbuffer(ctx, buf: buf[2], offset: offset[2]),
9751 ggml_vk_subbuffer(ctx, buf: buf[3], offset: offset[3]),
9752 ggml_vk_subbuffer(ctx, buf: buf[4], offset: offset[4]),
9753 ggml_vk_subbuffer(ctx, buf: buf[5], offset: offset[5]),
9754 ggml_vk_subbuffer(ctx, buf: buf[6], offset: offset[6]),
9755 }, push_constants: pc, elements);
9756 } else {
9757 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_RMS_NORM, pc: std::move(bin));
9758 }
9759
9760 if (ctx->do_add_rms_partials_offset_calculation) {
9761 ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, node: src0);
9762 ctx->do_add_rms_partials = false;
9763 ctx->do_add_rms_partials_offset_calculation = false;
9764 }
9765}
9766
9767static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9768 float * op_params = (float *)dst->op_params;
9769 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_RMS_NORM_BACK, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: op_params[0], .param2: 0.0f });
9770}
9771
9772static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9773 float * op_params = (float *)dst->op_params;
9774 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_L2_NORM, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: op_params[0], .param2: 0.0f });
9775}
9776
9777static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9778 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_UNARY, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: 0.0f, .param2: 0.0f });
9779}
9780
9781static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9782 const float * op_params_f = (const float *)dst->op_params;
9783
9784 const bool swapped = (bool)dst->op_params[1];
9785 const bool split = src1 != nullptr;
9786 const float alpha = op_params_f[2];
9787 const float limit = op_params_f[3];
9788
9789 GGML_ASSERT(ggml_is_contiguous(src0));
9790
9791 if (!split) {
9792 GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
9793 } else {
9794 GGML_ASSERT(src0->ne[0] == src1->ne[0]);
9795 GGML_ASSERT(src0->ne[0] == dst->ne[0]);
9796 GGML_ASSERT(src0->type == src1->type);
9797 }
9798
9799 const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
9800
9801 ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_GLU,
9802 pc: {
9803 .N: (uint32_t)ggml_nelements(tensor: dst),
9804 .ne00: (uint32_t)src0->ne[0],
9805 .ne20: (uint32_t)dst->ne[0],
9806 .mode: mode,
9807 .alpha: alpha,
9808 .limit: limit
9809 });
9810}
9811
9812static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9813 int32_t * op_params = (int32_t *)dst->op_params;
9814 ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_DIAG_MASK_INF, pc: { .ncols: (uint32_t)src0->ne[0], .rows_per_channel: (uint32_t)src0->ne[1], .n_past: op_params[0] });
9815}
9816
9817static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
9818 float * op_params = (float *)dst->op_params;
9819
9820 float scale = op_params[0];
9821 float max_bias = op_params[1];
9822
9823 const uint32_t ncols = (uint32_t)src0->ne[0];
9824 const uint32_t nrows_x = (uint32_t)ggml_nrows(tensor: src0);
9825 const uint32_t nrows_y = (uint32_t)src0->ne[1];
9826
9827 const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
9828 const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
9829 const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
9830 const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
9831 const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
9832
9833 const uint32_t n_head_kv = src0->ne[2];
9834 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(x: log2f(x: (float) n_head_kv));
9835
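    // ALiBi slope bases: heads with index below n_head_log2 use powers of m0, the rest use
    // powers of m1 (the standard ALiBi schedule); with max_bias == 0 both bases collapse to 1.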
9836 const float m0 = powf(x: 2.0f, y: -(max_bias ) / n_head_log2);
9837 const float m1 = powf(x: 2.0f, y: -(max_bias / 2.0f) / n_head_log2);
9838
9839 ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, src3: nullptr, dst, op: GGML_OP_SOFT_MAX, pc: {
9840 .KX: ncols,
9841 .KY: src1 != nullptr ? nrows_y : (uint32_t)0,
9842 .ne00: (uint32_t)src0->ne[0], .ne01: (uint32_t)src0->ne[1], .ne02: (uint32_t)src0->ne[2],
9843 .ne12: ne12, .ne13: ne13,
9844 .nb11: nb11, .nb12: nb12, .nb13: nb13,
9845 .scale: scale, .max_bias: max_bias,
9846 .m0: m0, .m1: m1,
9847 .n_head_log2: n_head_log2,
9848 .nrows_x: nrows_x,
9849 .has_sinks: src2 != nullptr
9850 });
9851}
9852
9853static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9854 float * op_params = (float *)dst->op_params;
9855 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SOFT_MAX_BACK, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)ggml_nrows(tensor: src0), .param1: op_params[0], .param2: op_params[1] });
9856}
9857
9858static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
9859 topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(num: ctx->num_additional_fused_ops);
9860 ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
9861 ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
9862 (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
9863 cgraph->nodes[node_idx + 5];
9864 ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
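    // The node offsets above depend on which topk_moe pattern was fused (early softmax with or
    // without weight normalization, or late softmax), so weights and ids live at different
    // positions within the fused subgraph.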
9865
9866 GGML_ASSERT(logits->type == GGML_TYPE_F32);
9867 GGML_ASSERT(weights->type == GGML_TYPE_F32);
9868 GGML_ASSERT(ids->type == GGML_TYPE_I32);
9869
9870 const int n_experts = logits->ne[0];
9871 const int n_rows = logits->ne[1];
9872 const int n_expert_used = weights->ne[1];
9873
9874 GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
9875
9876 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0: nullptr, src1: nullptr, src2: nullptr, dst: cgraph->nodes[node_idx], op: GGML_OP_SOFT_MAX);
9877
9878 ggml_pipeline_request_descriptor_sets(ctx, pipeline, n: 1);
9879
9880 vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, tensor: logits);
9881 vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, tensor: weights);
9882 vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, tensor: ids);
9883
9884 vk_op_topk_moe_push_constants pc {};
9885 pc.n_rows = n_rows;
9886 pc.n_expert_used = n_expert_used;
9887 if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
9888 ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
9889 pc.clamp_min = ggml_get_op_params_f32(tensor: clamp, i: 0);
9890 pc.clamp_max = ggml_get_op_params_f32(tensor: clamp, i: 1);
9891 }
9892
9893 GGML_ASSERT(n_expert_used <= n_experts);
9894
9895 const uint32_t rows_per_block = 4;
9896 std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
9897
9898 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, descriptor_buffer_infos: {logits_buf, weights_buf, ids_buf}, push_constants: pc, elements);
9899}
9900
9901static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
9902 ggml_tensor * dst = cgraph->nodes[node_idx];
9903 const ggml_tensor * src0 = dst->src[0];
9904 const ggml_tensor * src1 = dst->src[1];
9905 const ggml_tensor * src2 = dst->src[2];
9906 const ggml_tensor * src3 = nullptr;
9907 const int n_dims = ((int32_t *) dst->op_params)[1];
9908 const int mode = ((int32_t *) dst->op_params)[2];
9909 // const int n_ctx = ((int32_t *) dst->op_params)[3];
9910 const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
9911 const float freq_base = ((float *) dst->op_params)[5];
9912 const float beta_fast = ((float *) dst->op_params)[9];
9913 const float beta_slow = ((float *) dst->op_params)[10];
9914 int sections[4] {};
9915 if (mode & GGML_ROPE_TYPE_MROPE) {
9916 memcpy(dest: sections, src: (int32_t *) dst->op_params + 11, n: sizeof(int)*4);
9917 }
9918
9919 float corr_dims[2];
9920 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, dims: corr_dims);
9921
9922 uint32_t set_rows_stride = 0;
9923 // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride,
9924 // overrides dst with the set_rows node, and sets src3 to the row indices.
9925 if (ctx->num_additional_fused_ops > 0) {
9926 set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(type: cgraph->nodes[node_idx + 2]->type);
9927 src3 = cgraph->nodes[node_idx + 2]->src[1];
9928 dst = cgraph->nodes[node_idx + 2];
9929 }
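    // The rope push constants are still built from the original rope node; only dst, src3 and
    // set_rows_stride come from the fused set_rows node.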
9930
9931 ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, op: GGML_OP_ROPE,
9932 pc: ggml_vk_make_rope_constants(dst: cgraph->nodes[node_idx], src0, has_ff: src2 != nullptr, backprop, set_rows_stride));
9933}
9934
9935static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9936 int32_t * op_params = (int32_t *)dst->op_params;
9937
9938 uint32_t ncols = src0->ne[0];
9939 uint32_t nrows = ggml_nrows(tensor: src0);
9940
9941 ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ARGSORT, pc: {
9942 .ncols: ncols,
9943 .nrows: nrows,
9944 .order: op_params[0],
9945 });
9946}
9947
9948static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9949 vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src: src0, dst, n_cols: ggml_nelements(tensor: src0));
9950 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SUM, pc&: p);
9951}
9952
9953static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9954 vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src: src0, dst, n_cols: src0->ne[0]);
9955 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_SUM_ROWS, pc&: p);
9956}
9957
9958static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9959 vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src: src0, dst, n_cols: src0->ne[0]);
9960 p.weight = 1.0f / (float)src0->ne[0];
9961 ggml_vk_op_f32(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_MEAN, pc&: p);
9962}
9963
9964static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9965 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_ARGMAX, pc: { .KX: (uint32_t)src0->ne[0], .KY: (uint32_t)src0->ne[1], .param1: 0.0f, .param2: 0.0f });
9966}
9967
9968static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9969 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_COUNT_EQUAL, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: 0.0f, .param2: 0.0f });
9970}
9971
9972static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9973 const int32_t s0 = dst->op_params[0];
9974 const int32_t s1 = dst->op_params[1];
9975 const int32_t p0 = dst->op_params[2];
9976 const int32_t p1 = dst->op_params[3];
9977 const int32_t d0 = dst->op_params[4];
9978 const int32_t d1 = dst->op_params[5];
9979
9980 const bool is_2D = dst->op_params[6] == 1;
9981
9982 const uint32_t IC = src1->ne[is_2D ? 2 : 1];
9983 const uint32_t IH = is_2D ? src1->ne[1] : 1;
9984 const uint32_t IW = src1->ne[0];
9985
9986 const uint32_t KH = is_2D ? src0->ne[1] : 1;
9987 const uint32_t KW = src0->ne[0];
9988
9989 const uint32_t OH = is_2D ? dst->ne[2] : 1;
9990 const uint32_t OW = dst->ne[1];
9991
9992 const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
9993 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
9994
9995 const uint32_t pelements = OW * KW * KH;
9996
9997 const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9998 const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9999
10000 const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(tensor: dst) + dst->view_offs;
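    // The destination is passed as a raw buffer device address in the push constants rather than
    // being bound as a descriptor, presumably so the shader writes through the address directly.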
10001
10002 ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_IM2COL, pc: {
10003 .dst_addr: dst_addr,
10004 .batch_offset: batch_offset, .offset_delta: offset_delta,
10005 .IC: IC, .IW: IW, .IH: IH, .OW: OW, .OH: OH, .KW: KW, .KH: KH,
10006 .pelements: pelements,
10007 .CHW: IC * KH * KW,
10008 .s0: s0, .s1: s1, .p0: p0, .p1: p1, .d0: d0, .d1: d1,
10009 });
10010}
10011
10012static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10013 GGML_TENSOR_BINARY_OP_LOCALS
10014
10015 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
10016 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
10017 const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
10018 const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
10019 const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
10020 const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
10021 const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
10022 const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
10023 const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
10024 const int32_t IC = ((const int32_t *)(dst->op_params))[9];
10025
10026 const int64_t N = ne13 / IC;
10027 const int64_t ID = ne12;
10028 const int64_t IH = ne11;
10029 const int64_t IW = ne10;
10030
10031 const int64_t KD = ne02;
10032 const int64_t KH = ne01;
10033 const int64_t KW = ne00;
10034
10035 const int64_t OD = ne3 / N;
10036 const int64_t OH = ne2;
10037 const int64_t OW = ne1;
10038
10039 const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
10040 const vk_buffer d_buf = d_buf_ctx->dev_buffer;
10041
10042 const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(tensor: dst) + dst->view_offs;
10043
10044 vk_op_im2col_3d_push_constants pc {};
10045
10046 pc.dst_addr = dst_addr;
10047 pc.nb10 = nb10 / ggml_type_size(type: src1->type);
10048 pc.nb11 = nb11 / ggml_type_size(type: src1->type);
10049 pc.nb12 = nb12 / ggml_type_size(type: src1->type);
10050 pc.nb13 = nb13 / ggml_type_size(type: src1->type);
10051 pc.s0 = s0;
10052 pc.s1 = s1;
10053 pc.s2 = s2;
10054 pc.p0 = p0;
10055 pc.p1 = p1;
10056 pc.p2 = p2;
10057 pc.d0 = d0;
10058 pc.d1 = d1;
10059 pc.d2 = d2;
10060 pc.IW = IW;
10061 pc.IH = IH;
10062 pc.ID = ID;
10063 pc.IC = IC;
10064 pc.KW = KW;
10065 pc.OH = OH;
10066 pc.KD_KH_KW = KD*KH*KW;
10067 pc.KH_KW = KH*KW;
10068 pc.IC_KD_KH_KW = IC*KD*KH*KW;
10069 pc.N_OD_OH = N*OD*OH;
10070 pc.OD_OH = OD*OH;
10071 pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
10072 pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
10073 pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
10074
10075 ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_IM2COL_3D, pc: std::move(pc));
10076}
10077
10078static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10079 const uint32_t dim = dst->op_params[0];
10080 const uint32_t max_period = dst->op_params[1];
10081 const uint32_t nb1 = dst->nb[1] / ggml_type_size(type: dst->type);
10082
10083 ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_TIMESTEP_EMBEDDING, pc: {
10084 .nb1: nb1, .dim: dim, .max_period: max_period,
10085 });
10086}
10087
10088static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10089 // src0: (K, Cout, Cin, 1) -- kernel
10090 // src1: (L, Cin, 1, 1) -- input
10091 // dst: (*, Cout, 1, 1)
10092
10093 GGML_ASSERT(src0->type == GGML_TYPE_F32);
10094 GGML_ASSERT(src1->type == GGML_TYPE_F32);
10095 GGML_ASSERT( dst->type == GGML_TYPE_F32);
10096
10097 GGML_TENSOR_BINARY_OP_LOCALS
10098
10099 GGML_ASSERT(nb00 == sizeof(float));
10100 GGML_ASSERT(nb10 == sizeof(float));
10101
10102 const int32_t s0 = dst->op_params[0];
10103
10104 vk_op_conv_transpose_1d_push_constants p{};
10105 p.Cout = static_cast<uint32_t>(ne01);
10106 p.Cin = static_cast<uint32_t>(ne02);
10107 p.K = static_cast<uint32_t>(ne00);
10108 p.L = static_cast<uint32_t>(ne10);
10109 p.KL = static_cast<uint32_t>(ne0);
10110 p.nb01 = static_cast<uint32_t>(nb01 / nb00);
10111 p.nb02 = static_cast<uint32_t>(nb02 / nb00);
10112 p.nb11 = static_cast<uint32_t>(nb11 / nb10);
10113 p.nb1 = static_cast<uint32_t>(nb1 / nb0);
10114 p.s0 = static_cast<uint32_t>(s0);
10115
10116 ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_TRANSPOSE_1D, pc: std::move(p));
10117}
10118
10119static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10120 uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
10121 const int32_t k1 = dst->op_params[1];
10122 const int32_t k0 = dst->op_params[2];
10123 const int32_t s1 = dst->op_params[3];
10124 const int32_t s0 = dst->op_params[4];
10125 const int32_t p1 = dst->op_params[5];
10126 const int32_t p0 = dst->op_params[6];
10127
10128 const uint32_t IH = src0->ne[1];
10129 const uint32_t IW = src0->ne[0];
10130
10131 const uint32_t N = dst->ne[3];
10132
10133 const uint32_t OC = dst->ne[2];
10134 const uint32_t OH = dst->ne[1];
10135 const uint32_t OW = dst->ne[0];
10136
10137 const uint32_t parallel_elements = N * OC * OH * OW;
10138
10139 ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_POOL_2D, pc: {
10140 .IW: IW, .IH: IH, .OW: OW, .OH: OH, .OC: OC,
10141 .pelements: parallel_elements,
10142 .op: op,
10143 .k0: k0, .k1: k1, .s0: s0, .s1: s1, .p0: p0, .p1: p1,
10144 });
10145}
10146
10147static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
10148 const ggml_tensor * src1, ggml_tensor * dst) {
10149 GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
10150 GGML_ASSERT(src1->type == GGML_TYPE_F32);
10151 GGML_ASSERT(dst->type == GGML_TYPE_F32);
10152
10153 GGML_TENSOR_BINARY_OP_LOCALS
10154
10155 GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
10156 GGML_ASSERT(nb10 == sizeof(float));
10157 GGML_ASSERT(nb0 == sizeof(float));
10158
10159 vk_op_conv2d_push_constants p{};
10160 p.Cout = static_cast<uint32_t>(ne03);
10161 p.Cin = static_cast<uint32_t>(ne02);
10162 p.N = static_cast<uint32_t>(ne13);
10163
10164 p.KW = static_cast<uint32_t>(ne00);
10165 p.KH = static_cast<uint32_t>(ne01);
10166 p.W = static_cast<uint32_t>(ne10);
10167 p.H = static_cast<uint32_t>(ne11);
10168 p.OW = static_cast<uint32_t>(ne0);
10169 p.OH = static_cast<uint32_t>(ne1);
10170
10171 p.s0 = static_cast<uint32_t>(dst->op_params[0]);
10172 p.s1 = static_cast<uint32_t>(dst->op_params[1]);
10173 p.p0 = static_cast<uint32_t>(dst->op_params[2]);
10174 p.p1 = static_cast<uint32_t>(dst->op_params[3]);
10175 p.d0 = static_cast<uint32_t>(dst->op_params[4]);
10176 p.d1 = static_cast<uint32_t>(dst->op_params[5]);
10177
10178 p.nb01 = static_cast<uint32_t>(nb01 / nb00);
10179 p.nb02 = static_cast<uint32_t>(nb02 / nb00);
10180 p.nb03 = static_cast<uint32_t>(nb03 / nb00);
10181
10182 p.nb11 = static_cast<uint32_t>(nb11 / nb10);
10183 p.nb12 = static_cast<uint32_t>(nb12 / nb10);
10184 p.nb13 = static_cast<uint32_t>(nb13 / nb10);
10185
10186 p.nb1 = static_cast<uint32_t>(nb1 / nb0);
10187 p.nb2 = static_cast<uint32_t>(nb2 / nb0);
10188 p.nb3 = static_cast<uint32_t>(nb3 / nb0);
10189
10190 GGML_ASSERT(ne03 == ne2);
10191 GGML_ASSERT(ne02 == ne12);
10192
10193 ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_2D, pc: std::move(p));
10194}
10195
10196static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
10197 const ggml_tensor * src1, ggml_tensor * dst) {
10198 GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
10199 GGML_ASSERT(src1->type == GGML_TYPE_F32);
10200 GGML_ASSERT(dst->type == GGML_TYPE_F32);
10201
10202 GGML_TENSOR_BINARY_OP_LOCALS
10203
10204 GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
10205 GGML_ASSERT(nb10 == sizeof(float));
10206 GGML_ASSERT(nb0 == sizeof(float));
10207
10208 vk_op_conv_transpose_2d_push_constants p{};
10209 p.Cout = static_cast<uint32_t>(ne02);
10210 p.Cin = static_cast<uint32_t>(ne03);
10211 p.N = static_cast<uint32_t>(ne13);
10212
10213 p.KW = static_cast<uint32_t>(ne00);
10214 p.KH = static_cast<uint32_t>(ne01);
10215 p.W = static_cast<uint32_t>(ne10);
10216 p.H = static_cast<uint32_t>(ne11);
10217 p.OW = static_cast<uint32_t>(ne0);
10218 p.OH = static_cast<uint32_t>(ne1);
10219
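    // Only a single stride (op_params[0]) is provided and it is applied to both axes; padding and
    // dilation are fixed to 0 and 1 here.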
10220 p.s0 = static_cast<uint32_t>(dst->op_params[0]);
10221 p.s1 = static_cast<uint32_t>(dst->op_params[0]);
10222 p.p0 = 0;
10223 p.p1 = 0;
10224 p.d0 = 1;
10225 p.d1 = 1;
10226
10227 p.nb01 = static_cast<uint32_t>(nb01 / nb00);
10228 p.nb02 = static_cast<uint32_t>(nb02 / nb00);
10229 p.nb03 = static_cast<uint32_t>(nb03 / nb00);
10230
10231 p.nb11 = static_cast<uint32_t>(nb11 / nb10);
10232 p.nb12 = static_cast<uint32_t>(nb12 / nb10);
10233 p.nb13 = static_cast<uint32_t>(nb13 / nb10);
10234
10235 p.nb1 = static_cast<uint32_t>(nb1 / nb0);
10236 p.nb2 = static_cast<uint32_t>(nb2 / nb0);
10237 p.nb3 = static_cast<uint32_t>(nb3 / nb0);
10238
10239 GGML_ASSERT(ne02 == ne2);
10240 GGML_ASSERT(ne03 == ne12);
10241
10242 ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_TRANSPOSE_2D, pc: std::move(p));
10243}
10244
10245static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10246 vk_op_conv2d_dw_push_constants p{};
10247 p.ne = ggml_nelements(tensor: dst);
10248 p.channels = dst->ne[2];
10249 p.batches = dst->ne[3];
10250 p.dst_w = dst->ne[0];
10251 p.dst_h = dst->ne[1];
10252 p.src_w = src1->ne[0];
10253 p.src_h = src1->ne[1];
10254 p.knl_w = src0->ne[0];
10255 p.knl_h = src0->ne[1];
10256 p.stride_x = dst->op_params[0];
10257 p.stride_y = dst->op_params[1];
10258 p.pad_x = dst->op_params[2];
10259 p.pad_y = dst->op_params[3];
10260 p.dilation_x = dst->op_params[4];
10261 p.dilation_y = dst->op_params[5];
10262
10263 GGML_ASSERT(src0->ne[3] == p.channels);
10264 GGML_ASSERT(src1->ne[3] == p.batches);
10265
10266 ggml_vk_op_f32(ctx, subctx, src0, src1, src2: nullptr, src3: nullptr, dst, op: GGML_OP_CONV_2D_DW, pc: std::move(p));
10267}
10268
10269static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10270 const float * op_params = (const float *)dst->op_params;
10271 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1: nullptr, src2: nullptr, src3: nullptr, dst, op: GGML_OP_LEAKY_RELU, pc: { .KX: (uint32_t)ggml_nelements(tensor: src0), .KY: 0, .param1: op_params[0], .param2: 0.0f });
10272}
10273
10274#ifdef GGML_VULKAN_RUN_TESTS
10275static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
10276 if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
10277 return;
10278 }
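    // Clamp the requested coordinates so the 10x10 window printed below starts at a non-negative index.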
10279 i0 = std::max(i0, 5);
10280 i1 = std::max(i1, 5);
10281 i2 = std::max(i2, 0);
10282 fprintf(stderr, " ");
10283 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
10284 fprintf(stderr, "%7d ", idx1);
10285 }
10286 fprintf(stderr, "\n");
10287 for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
10288 fprintf(stderr, "%7d: ", idx0);
10289 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
10290 if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {
10291 float val;
10292 if (type == GGML_TYPE_F32) {
10293 val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);
10294 } else if (type == GGML_TYPE_F16) {
10295 val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));
10296 } else {
10297 GGML_ABORT("fatal error");
10298 }
10299 fprintf(stderr, "% 7.2f ", val);
10300 } else {
10301 fprintf(stderr, " ");
10302 }
10303 }
10304 fprintf(stderr, "\n");
10305 }
10306}
10307
10308template <typename X_TYPE, typename Y_TYPE>
10309static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {
10310 VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")");
10311 const size_t x_ne = m * k * batch;
10312 const size_t y_ne = k * n * batch;
10313 const size_t d_ne = m * n * batch;
10314
10315 vk_pipeline p;
10316 std::string shname;
10317 if (shader_size == 0) {
10318 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10319 p = ctx->device->pipeline_matmul_f32->a_s;
10320 shname = "F32_ALIGNED_S";
10321 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10322 p = ctx->device->pipeline_matmul_f32_f16->a_s;
10323 shname = "F32_F16_ALIGNED_S";
10324 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10325 p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s;
10326 shname = "F16_F32_ALIGNED_S";
10327 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10328 p = ctx->device->pipeline_matmul_f16.f32acc->a_s;
10329 shname = "F16_ALIGNED_S";
10330 } else {
10331 GGML_ABORT("fatal error");
10332 }
10333 } else if (shader_size == 1) {
10334 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10335 p = ctx->device->pipeline_matmul_f32->a_m;
10336 shname = "F32_ALIGNED_M";
10337 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10338 p = ctx->device->pipeline_matmul_f32_f16->a_m;
10339 shname = "F32_F16_ALIGNED_M";
10340 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10341 p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m;
10342 shname = "F16_F32_ALIGNED_M";
10343 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10344 p = ctx->device->pipeline_matmul_f16.f32acc->a_m;
10345 shname = "F16_ALIGNED_M";
10346 } else {
10347 GGML_ABORT("fatal error");
10348 }
10349 } else if (shader_size == 2) {
10350 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10351 p = ctx->device->pipeline_matmul_f32->a_l;
10352 shname = "F32_ALIGNED_L";
10353 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10354 p = ctx->device->pipeline_matmul_f32_f16->a_l;
10355 shname = "F32_F16_ALIGNED_L";
10356 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10357 p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l;
10358 shname = "F16_F32_ALIGNED_L";
10359 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10360 p = ctx->device->pipeline_matmul_f16.f32acc->a_l;
10361 shname = "F16_ALIGNED_L";
10362 } else {
10363 GGML_ABORT("fatal error");
10364 }
10365 } else {
10366 GGML_ASSERT(0);
10367 }
10368
10369 const size_t kpad = ggml_vk_align_size(k, p->align);
10370
10371 if (k != kpad) {
10372 if (shader_size == 0) {
10373 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10374 p = ctx->device->pipeline_matmul_f32->s;
10375 shname = "F32_S";
10376 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10377 p = ctx->device->pipeline_matmul_f32_f16->s;
10378 shname = "F32_F16_S";
10379 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10380 p = ctx->device->pipeline_matmul_f16_f32.f32acc->s;
10381 shname = "F16_F32_S";
10382 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10383 p = ctx->device->pipeline_matmul_f16.f32acc->s;
10384 shname = "F16_S";
10385 }
10386 } else if (shader_size == 1) {
10387 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10388 p = ctx->device->pipeline_matmul_f32->m;
10389 shname = "F32_M";
10390 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10391 p = ctx->device->pipeline_matmul_f32_f16->m;
10392 shname = "F32_F16_M";
10393 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10394 p = ctx->device->pipeline_matmul_f16_f32.f32acc->m;
10395 shname = "F16_F32_M";
10396 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10397 p = ctx->device->pipeline_matmul_f16.f32acc->m;
10398 shname = "F16_M";
10399 }
10400 } else if (shader_size == 2) {
10401 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10402 p = ctx->device->pipeline_matmul_f32->l;
10403 shname = "F32_L";
10404 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10405 p = ctx->device->pipeline_matmul_f32_f16->l;
10406 shname = "F32_F16_L";
10407 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
10408 p = ctx->device->pipeline_matmul_f16_f32.f32acc->l;
10409 shname = "F16_F32_L";
10410 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
10411 p = ctx->device->pipeline_matmul_f16.f32acc->l;
10412 shname = "F16_L";
10413 }
10414 }
10415 }
10416
10417 ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
10418 if (split_k > 1) {
10419 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
10420
10421 if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
10422 // Resize buffer
10423 if (ctx->prealloc_split_k != nullptr) {
10424 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
10425 }
10426 ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10427 }
10428 }
10429
10430 ggml_pipeline_allocate_descriptor_sets(ctx);
10431
10432 vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10433 vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10434 vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10435
10436 X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
10437 Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
10438 float* d = (float *) malloc(sizeof(float) * d_ne);
10439
10440 for (size_t i = 0; i < x_ne; i++) {
10441 if (std::is_same<float, X_TYPE>()) {
10442 x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
10443 // x[i] = 1.0f;
10444 // x[i] = i + 1;
10445 // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
10446 } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
10447 x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
10448 // x[i] = ggml_fp32_to_fp16(1.0f);
10449 // x[i] = ggml_fp32_to_fp16(i + 1);
10450 // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
10451 } else {
10452 GGML_ABORT("fatal error");
10453 }
10454 }
10455 for (size_t i = 0; i < y_ne; i++) {
10456 if (std::is_same<float, Y_TYPE>()) {
10457 y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
10458 // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
10459 // y[i] = i + 1;
10460 } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
10461 y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
10462 // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
10463 // y[i] = ggml_fp32_to_fp16(i + 1);
10464 } else {
10465 GGML_ABORT("fatal error");
10466 }
10467 }
10468
10469 ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
10470 ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
10471
10472 vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
10473 ggml_vk_ctx_begin(ctx->device, subctx);
10474 for (size_t i = 0; i < num_it; i++) {
10475 ggml_vk_matmul(
10476 ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
10477 m, n, k,
10478 k, k, m, k*m, k*n, m*n,
10479 split_k, batch, batch, batch, 1, 1, n
10480 );
10481 }
10482 ggml_vk_ctx_end(subctx);
10483
10484 auto begin = std::chrono::high_resolution_clock::now();
10485 ggml_vk_submit(subctx, ctx->fence);
10486 VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
10487 ctx->device->device.resetFences({ ctx->fence });
10488 ggml_vk_queue_command_pools_cleanup(ctx->device);
10489
10490 auto end = std::chrono::high_resolution_clock::now();
10491 double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
10492
10493 // copy dst to host
10494 ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);
10495
10496 float * d_chk = (float *) malloc(sizeof(float) * d_ne);
10497
10498 ggml_init_params iparams = {
10499 /*.mem_size =*/ 1024*1024*1024,
10500 /*.mem_buffer =*/ NULL,
10501 /*.no_alloc =*/ true,
10502 };
10503
10504 ggml_context * ggml_ctx = ggml_init(iparams);
10505
10506 ggml_type src0_type;
10507 ggml_type src1_type;
10508
10509 if (std::is_same<float, X_TYPE>()) {
10510 src0_type = GGML_TYPE_F32;
10511 } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
10512 src0_type = GGML_TYPE_F16;
10513 } else {
10514 GGML_ABORT("fatal error");
10515 }
10516 if (std::is_same<float, Y_TYPE>()) {
10517 src1_type = GGML_TYPE_F32;
10518 } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
10519 src1_type = GGML_TYPE_F16;
10520 } else {
10521 GGML_ABORT("fatal error");
10522 }
10523
10524 ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
10525 ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
10526 ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
10527
10528 src0_ggml->data = x;
10529 src1_ggml->data = y;
10530 tensor_ggml->data = d_chk;
10531
10532 ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
10533 ggml_build_forward_expand(cgraph, tensor_ggml);
10534
10535 ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
10536
10537 ggml_free(ggml_ctx);
10538
10539 double avg_err = 0.0;
10540 int first_err_n = -1;
10541 int first_err_m = -1;
10542 int first_err_b = -1;
10543
10544 for (size_t i = 0; i < m*n*batch; i++) {
10545 double err = std::fabs(d[i] - d_chk[i]);
10546 avg_err += err;
10547
10548 if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
10549 first_err_b = i / (m * n);
10550 first_err_n = (i % (m * n)) / m;
10551 first_err_m = (i % (m * n)) % m;
10552 }
10553 }
10554
10555 avg_err /= m * n;
10556
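    // 2*m*n*k FLOPs per matmul, times batch and num_it iterations; time is in milliseconds.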
10557 double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
10558
10559 std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
10560
10561 if (avg_err > 0.1 || std::isnan(avg_err)) {
10562 std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
10563 std::cerr << "Actual result: " << std::endl << std::endl;
10564 ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
10565 std::cerr << "Expected result: " << std::endl << std::endl;
10566 ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
10567
10568 if (split_k > 1) {
10569 float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
10570 ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
10571
10572 std::cerr << "d_buf0: " << std::endl << std::endl;
10573 ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
10574
10575 std::cerr << "d_buf1: " << std::endl << std::endl;
10576 ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
10577
10578 std::cerr << "d_buf2: " << std::endl << std::endl;
10579 ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
10580
10581 std::cerr << "d_buf3: " << std::endl << std::endl;
10582 ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
10583
10584 free(split_k_buf);
10585 }
10586 }
10587
10588 free(d_chk);
10589
10590 ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
10591 ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
10592
10593 ggml_vk_destroy_buffer(d_X);
10594 ggml_vk_destroy_buffer(d_Y);
10595 ggml_vk_destroy_buffer(d_D);
10596
10597 free(x);
10598 free(y);
10599 free(d);
10600}
10601
10602static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
10603 if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
10604 return;
10605 }
10606 i0 = std::max(i0, 5);
10607 i1 = std::max(i1, 5);
10608 i2 = std::max(i2, 0);
10609 i3 = std::max(i3, 0);
10610 fprintf(stderr, " ");
10611 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
10612 fprintf(stderr, "%7d ", idx1);
10613 }
10614 fprintf(stderr, "\n");
10615 for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
10616 fprintf(stderr, "%7d: ", idx0);
10617 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
10618 if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
10619 float val;
10620 if (tensor->type == GGML_TYPE_F32) {
10621 val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
10622 } else if (tensor->type == GGML_TYPE_F16) {
10623 val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
10624 } else {
10625 GGML_ABORT("fatal error");
10626 }
10627 fprintf(stderr, "% 7.2f ", val);
10628 } else {
10629 fprintf(stderr, " ");
10630 }
10631 }
10632 fprintf(stderr, "\n");
10633 }
10634}
10635
10636static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
10637 ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
10638}
10639
10640static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {
10641 if (quant == GGML_TYPE_F32) {
10642 memcpy(to, from, sizeof(float) * ne);
10643 return;
10644 }
10645
10646 const auto * tt = ggml_get_type_traits(quant);
10647
10648 ggml_to_float_t dequant_fn = tt->to_float;
10649
10650 dequant_fn(from, to, ne);
10651}
10652
10653static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
10654 VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")");
10655 const size_t x_sz = sizeof(float) * ne;
10656 const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
10657 const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
10658 float * x = (float *) malloc(x_sz);
10659 void * qx = malloc(qx_sz);
10660 vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10661 vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10662 float * x_ref = (float *) malloc(x_sz);
10663 ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
10664
10665 for (size_t i = 0; i < ne; i++) {
10666 x[i] = rand() / (float)RAND_MAX;
10667 }
10668
10669 vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);
10670
10671 ggml_vk_quantize_data(x, qx, ne, quant);
10672 ggml_vk_dequantize_data(qx, x_ref, ne, quant);
10673
10674 ggml_pipeline_request_descriptor_sets(ctx, p, 1);
10675
10676 ggml_pipeline_allocate_descriptor_sets(ctx);
10677
10678 ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
10679
10680 vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
10681 ggml_vk_ctx_begin(ctx->device, subctx);
10682 const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
10683 ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
10684 ggml_vk_ctx_end(subctx);
10685
10686 auto begin = std::chrono::high_resolution_clock::now();
10687
10688 ggml_vk_submit(subctx, ctx->fence);
10689 VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
10690 ctx->device->device.resetFences({ ctx->fence });
10691 ggml_vk_queue_command_pools_cleanup(ctx->device);
10692
10693 auto end = std::chrono::high_resolution_clock::now();
10694
10695 double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
10696 ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);
10697
10698 int first_err = -1;
10699
10700 double avg_err = 0.0;
10701 for (size_t i = 0; i < ne; i++) {
10702 double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));
10703 avg_err += error;
10704
10705 if (first_err < 0 && error > 0.05) {
10706 first_err = i;
10707 }
10708 }
10709
10710 avg_err /= ne;
10711
10712 std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
10713
10714 if (avg_err > 0.1) {
10715 std::cerr << "first_error = " << first_err << std::endl;
10716 std::cerr << "Actual result: " << std::endl << std::endl;
10717 for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
10718 std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
10719 }
10720 std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
10721 for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
10722 std::cerr << x_ref[i] << ", ";
10723 }
10724 std::cerr << std::endl;
10725 }
10726
10727 ggml_vk_destroy_buffer(x_buf);
10728 ggml_vk_destroy_buffer(qx_buf);
10729
10730 free(x);
10731 free(qx);
10732 free(x_ref);
10733 free(x_chk);
10734}
10735
10736// This does not work without ggml q8_1 quantization support
10737//
10738// typedef uint16_t ggml_half;
10739// typedef uint32_t ggml_half2;
10740//
10741// #define QK8_1 32
10742// typedef struct {
10743// union {
10744// struct {
10745// ggml_half d; // delta
10746// ggml_half s; // d * sum(qs[i])
10747// } GGML_COMMON_AGGR_S;
10748// ggml_half2 ds;
10749// } GGML_COMMON_AGGR_U;
10750// int8_t qs[QK8_1]; // quants
10751// } block_q8_1;
10752//
10753// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
10754// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
10755// GGML_ASSERT(quant == GGML_TYPE_Q8_1);
10756//
10757// const size_t x_sz = sizeof(float) * ne;
10758// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
10759// float * x = (float *) malloc(x_sz);
10760// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
10761// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
10762// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10763// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10764//
10765// for (size_t i = 0; i < ne; i++) {
10766// x[i] = rand() / (float)RAND_MAX;
10767// }
10768//
10769// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
10770//
10771// ggml_pipeline_request_descriptor_sets(ctx, p, 1);
10772//
10773// ggml_pipeline_allocate_descriptor_sets(ctx);
10774//
10775// ggml_vk_buffer_write(x_buf, 0, x, x_sz);
10776//
10777// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
10778// ggml_vk_ctx_begin(ctx->device, subctx);
10779// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);
10780// ggml_vk_ctx_end(subctx);
10781//
10782// auto begin = std::chrono::high_resolution_clock::now();
10783//
10784// ggml_vk_submit(subctx, ctx->fence);
10785// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
10786// ctx->device->device.resetFences({ ctx->fence });
10787// ggml_vk_queue_command_pools_cleanup(ctx->device);
10788//
10789// auto end = std::chrono::high_resolution_clock::now();
10790//
10791// double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
10792// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
10793//
10794// ggml_vk_quantize_data(x, qx_res, ne, quant);
10795//
10796// int first_err = -1;
10797//
10798// for (size_t i = 0; i < ne / 32; i++) {
10799// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
10800//
10801// if (first_err < 0 && error > 0.1) {
10802// first_err = i;
10803// }
10804//
10805// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
10806//
10807// if (first_err < 0 && error > 0.1) {
10808// first_err = i;
10809// }
10810//
10811// for (size_t j = 0; j < 32; j++) {
10812// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
10813//
10814// if (first_err < 0 && error > 1) {
10815// first_err = i;
10816// }
10817// }
10818// }
10819//
10820// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
10821//
10822// if (first_err != -1) {
10823// std::cerr << "first_error = " << first_err << std::endl;
10824// std::cerr << "Actual result: " << std::endl << std::endl;
10825// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
10826// for (size_t j = 0; j < 32; j++) {
10827// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
10828// }
10829// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
10830// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
10831// for (size_t j = 0; j < 32; j++) {
10832// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
10833// }
10834// std::cerr << std::endl;
10835// }
10836//
10837// ggml_vk_destroy_buffer(x_buf);
10838// ggml_vk_destroy_buffer(qx_buf);
10839//
10840// free(x);
10841// free(qx);
10842// free(qx_res);
10843// }
10844
10845static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
10846 VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
10847 const size_t x_ne = m * k * batch;
10848 const size_t y_ne = k * n * batch;
10849 const size_t d_ne = m * n * batch;
10850
10851 vk_matmul_pipeline2 * pipelines;
10852
10853 if (mmq) {
10854 pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
10855 } else {
10856 pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
10857 }
10858
10859 const bool fp16acc = ctx->device->fp16;
10860
10861 vk_pipeline p;
10862 std::string shname;
10863 if (shader_size == 0) {
10864 p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
10865 shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
10866 } else if (shader_size == 1) {
10867 p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
10868 shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
10869 } else if (shader_size == 2) {
10870 p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
10871 shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
10872 } else {
10873 GGML_ASSERT(0);
10874 }
10875
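// The aligned (a_*) pipeline variants assume k is a multiple of the pipeline's alignment; MMQ has no aligned path here, so in either case fall back to the plain variants below.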
10876 const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
10877
10878 if (mmq || k != kpad) {
10879 if (shader_size == 0) {
10880 p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
10881 shname = std::string(ggml_type_name(quant)) + "_S";
10882 } else if (shader_size == 1) {
10883 p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
10884 shname = std::string(ggml_type_name(quant)) + "_M";
10885 } else if (shader_size == 2) {
10886 p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
10887 shname = std::string(ggml_type_name(quant)) + "_L";
10888 } else {
10889 GGML_ASSERT(0);
10890 }
10891 }
10892
10893 if (p == nullptr) {
10894 std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
10895 return;
10896 }
10897
10898 const size_t x_sz = sizeof(float) * x_ne;
10899 const size_t y_sz = sizeof(float) * y_ne;
10900 const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
10901 const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
10902 const size_t d_sz = sizeof(float) * d_ne;
10903 float * x = (float *) malloc(x_sz);
10904 float * y = (float *) malloc(y_sz);
10905 void * qx = malloc(qx_sz);
10906 vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10907 vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10908 vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10909 vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10910 float * d = (float *) malloc(d_sz);
10911 float * d_chk = (float *) malloc(d_sz);
10912
10913 for (size_t i = 0; i < x_ne; i++) {
10914 x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
10915 // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
10916 // x[i] = i % k;
10917 }
10918
10919 ggml_vk_quantize_data(x, qx, x_ne, quant);
10920
10921 for (size_t i = 0; i < y_ne; i++) {
10922 y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
10923 // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
10924 // y[i] = i % k;
10925 }
10926
10927 ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
10928 if (split_k > 1) {
10929 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
10930
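// split_k writes split_k partial results of d size each before the reduce pass, so the scratch buffer must be large enough to hold all of them.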
10931 if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
10932 // Resize buffer
10933 if (ctx->prealloc_split_k != nullptr) {
10934 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
10935 }
10936 ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
10937 }
10938 }
10939 if (mmq) {
10940 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it);
10941 }
10942
10943 ggml_pipeline_allocate_descriptor_sets(ctx);
10944
10945 ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
10946 ggml_vk_buffer_write(y_buf, 0, y, y_sz);
10947
10948 vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
10949 ggml_vk_ctx_begin(ctx->device, subctx);
10950 if (mmq) {
10951 for (size_t i = 0; i < num_it; i++) {
10952 ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
10953 ggml_vk_matmul(
10954 ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
10955 m, n, k,
10956 k, k, m, k*m, k*n, m*n,
10957 split_k, batch, batch, batch, 1, 1, n
10958 );
10959 }
10960 } else {
10961 for (size_t i = 0; i < num_it; i++) {
10962 ggml_vk_matmul(
10963 ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
10964 m, n, k,
10965 k, k, m, k*m, k*n, m*n,
10966 split_k, batch, batch, batch, 1, 1, n
10967 );
10968 }
10969 }
10970 ggml_vk_ctx_end(subctx);
10971
10972 auto begin = std::chrono::high_resolution_clock::now();
10973
10974 ggml_vk_submit(subctx, ctx->fence);
10975 VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant_matmul waitForFences");
10976 ctx->device->device.resetFences({ ctx->fence });
10977 ggml_vk_queue_command_pools_cleanup(ctx->device);
10978
10979 auto end = std::chrono::high_resolution_clock::now();
10980
10981 double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
10982 ggml_vk_buffer_read(d_buf, 0, d, d_sz);
10983
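// Compute the reference result on the CPU with a minimal ggml graph (no_alloc, tensor data pointed at the existing host buffers).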
10984 ggml_init_params iparams = {
10985 /*.mem_size =*/ 1024*1024*1024,
10986 /*.mem_buffer =*/ NULL,
10987 /*.no_alloc =*/ true,
10988 };
10989
10990 ggml_context * ggml_ctx = ggml_init(iparams);
10991
10992 ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
10993 ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
10994 ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
10995
10996 src0_ggml->data = qx;
10997 src1_ggml->data = y;
10998 tensor_ggml->data = d_chk;
10999
11000 ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
11001 ggml_build_forward_expand(cgraph, tensor_ggml);
11002
11003 ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
11004
11005 ggml_free(ggml_ctx);
11006
11007 double avg_err = 0.0;
11008 int first_err_n = -1;
11009 int first_err_m = -1;
11010 int first_err_b = -1;
11011
11012 for (size_t i = 0; i < m*n*batch; i++) {
11013 double err = std::fabs(d[i] - d_chk[i]);
11014 avg_err += err;
11015
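// Remember the first element whose error exceeds the threshold, decomposing the flat index into (batch, n, m) coordinates for the dumps below.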
11016 if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
11017 first_err_b = i / (m * n);
11018 first_err_n = (i % (m * n)) / m;
11019 first_err_m = (i % (m * n)) % m;
11020 }
11021 }
11022
11023 avg_err /= m * n * batch;
11024
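// A matmul costs roughly 2*m*n*k FLOPs per batch element; scale by batch and iteration count and divide by the elapsed time to report TFLOPS.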
11025 double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
11026
11027 std::cerr << "TEST dequant matmul " << shname;
11028 if (mmq) {
11029 std::cerr << " mmq";
11030 }
11031 std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
11032
11033 if (avg_err > 0.01 || std::isnan(avg_err)) {
11034 std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
11035 std::cerr << "Actual result: " << std::endl << std::endl;
11036 ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11037 std::cerr << std::endl;
11038 std::cerr << "Expected result: " << std::endl << std::endl;
11039 ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11040
11041 std::cerr << "src0: " << std::endl << std::endl;
11042 ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
11043 std::cerr << std::endl;
11044 std::cerr << "src1: " << std::endl << std::endl;
11045 ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
11046
11047 if (split_k > 1) {
11048 float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
11049 ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
11050
11051 std::cerr << "d_buf0: " << std::endl << std::endl;
11052 ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11053
11054 std::cerr << "d_buf1: " << std::endl << std::endl;
11055 ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11056
11057 std::cerr << "d_buf2: " << std::endl << std::endl;
11058 ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11059
11060 std::cerr << "d_buf3: " << std::endl << std::endl;
11061 ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11062
11063 free(split_k_buf);
11064 }
11065 }
11066
11067 ggml_vk_destroy_buffer(qx_buf);
11068 ggml_vk_destroy_buffer(y_buf);
11069 ggml_vk_destroy_buffer(qy_buf);
11070 ggml_vk_destroy_buffer(d_buf);
11071
11072 free(x);
11073 free(qx);
11074 free(y);
11075 free(d);
11076 free(d_chk);
11077}
11078#endif
11079
11080static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx) {
11081#if defined(GGML_VULKAN_RUN_TESTS)
11082 const std::vector<size_t> vals {
11083 512, 512, 128,
11084 128, 512, 512,
11085 4096, 512, 4096,
11086 11008, 512, 4096,
11087 4096, 512, 11008,
11088 32000, 512, 4096,
11089 8, 8, 8,
11090 100, 46, 576,
11091 623, 111, 128,
11092 100, 46, 558,
11093 512, 1, 256,
11094 128, 110, 622,
11095 511, 511, 127,
11096 511, 511, 7,
11097 511, 511, 17,
11098 49, 49, 128,
11099 128, 49, 49,
11100 4096, 49, 4096,
11101 };
11102 const size_t num_it = 100;
11103
11104 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
11105 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
11106 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
11107
11108 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
11109 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
11110 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
11111
11112 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
11113 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
11114 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
11115
11116 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
11117 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
11118 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
11119
11120 abort();
11121
11122 for (size_t i = 0; i < vals.size(); i += 3) {
11123 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
11124 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
11125 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
11126 std::cerr << '\n';
11127 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
11128 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
11129 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
11130 std::cerr << '\n';
11131 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
11132 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
11133 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
11134 std::cerr << '\n' << std::endl;
11135
11136 if (vals[i + 2] % 32 == 0) {
11137 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
11138 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
11139 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
11140 std::cerr << '\n';
11141 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
11142 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
11143 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
11144 std::cerr << '\n';
11145 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
11146 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
11147 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
11148 std::cerr << '\n' << std::endl;
11149 }
11150
11151 if (vals[i + 2] % 256 == 0) {
11152 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
11153 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
11154 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
11155 std::cerr << '\n';
11156 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
11157 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
11158 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
11159 std::cerr << '\n';
11160 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
11161 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
11162 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
11163 std::cerr << '\n' << std::endl;
11164 }
11165 }
11166
11167 GGML_ABORT("fatal error");
11168#endif
11169
11170 if (subctx) {
11171 // Submit and wait for any pending work before reallocating the buffers
11172 ggml_vk_ctx_end(subctx);
11173 ggml_vk_submit(subctx, ctx->fence);
11174 ggml_vk_wait_for_fence(ctx);
11175 ggml_vk_ctx_begin(ctx->device, subctx);
11176 }
11177
11178 if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
11179 VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
11180 // Resize buffer
11181 if (ctx->prealloc_x != nullptr) {
11182 ggml_vk_destroy_buffer(ctx->prealloc_x);
11183 }
11184 ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
11185 }
11186 if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
11187 VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
11188 // Resize buffer
11189 if (ctx->prealloc_y != nullptr) {
11190 ggml_vk_destroy_buffer(ctx->prealloc_y);
11191 }
11192 ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
11193 }
11194 if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
11195 VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
11196 // Resize buffer
11197 if (ctx->prealloc_split_k != nullptr) {
11198 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
11199 }
11200 ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
11201 }
11202 if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
11203 VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_size_add_rms_partials << ")");
11204 // Resize buffer
11205 if (ctx->prealloc_add_rms_partials != nullptr) {
11206 ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
11207 }
11208 ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
11209 }
11210}
11211
11212static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
11213
11214// Returns true if node has enqueued work into the queue, false otherwise
11215 // If submit is true, all operations queued so far are submitted to Vulkan, to overlap command buffer creation with GPU execution.
11216static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
11217 ggml_tensor * node = cgraph->nodes[node_idx];
11218 if (ggml_is_empty(node) || !node->buffer) {
11219 return false;
11220 }
11221
11222 VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
11223 ctx->semaphore_idx = 0;
11224
11225 ggml_tensor * src0 = node->src[0];
11226 ggml_tensor * src1 = node->src[1];
11227 ggml_tensor * src2 = node->src[2];
11228 ggml_tensor * src3 = node->src[3];
11229
11230 switch (node->op) {
11231 // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
11232 case GGML_OP_RESHAPE:
11233 case GGML_OP_VIEW:
11234 case GGML_OP_PERMUTE:
11235 case GGML_OP_TRANSPOSE:
11236 case GGML_OP_NONE:
11237 return false;
11238 case GGML_OP_UNARY:
11239 switch (ggml_get_unary_op(node)) {
11240 case GGML_UNARY_OP_EXP:
11241 case GGML_UNARY_OP_SILU:
11242 case GGML_UNARY_OP_GELU:
11243 case GGML_UNARY_OP_GELU_ERF:
11244 case GGML_UNARY_OP_GELU_QUICK:
11245 case GGML_UNARY_OP_RELU:
11246 case GGML_UNARY_OP_TANH:
11247 case GGML_UNARY_OP_SIGMOID:
11248 case GGML_UNARY_OP_HARDSIGMOID:
11249 case GGML_UNARY_OP_HARDSWISH:
11250 break;
11251 default:
11252 return false;
11253 }
11254 break;
11255 case GGML_OP_GLU:
11256 switch (ggml_get_glu_op(node)) {
11257 case GGML_GLU_OP_GEGLU:
11258 case GGML_GLU_OP_REGLU:
11259 case GGML_GLU_OP_SWIGLU:
11260 case GGML_GLU_OP_SWIGLU_OAI:
11261 case GGML_GLU_OP_GEGLU_ERF:
11262 case GGML_GLU_OP_GEGLU_QUICK:
11263 break;
11264 default:
11265 return false;
11266 }
11267 break;
11268 case GGML_OP_ADD:
11269 {
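// Look ahead past any already-fused ops: if the next node is a single-row RMS_NORM consuming this add chain's output and the device supports the add+rms fusion, let the add also accumulate the partial sums the norm needs, provided the preallocated partials buffer has room.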
11270 int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
11271 if (next_node_idx < cgraph->n_nodes &&
11272 cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
11273 cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
11274 ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
11275 ctx->device->add_rms_fusion) {
11276 uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
11277 ctx->do_add_rms_partials_offset_calculation = true;
11278 if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
11279 ctx->do_add_rms_partials = true;
11280 }
11281 }
11282 } break;
11283 case GGML_OP_REPEAT:
11284 case GGML_OP_REPEAT_BACK:
11285 case GGML_OP_GET_ROWS:
11286 case GGML_OP_ADD_ID:
11287 case GGML_OP_ACC:
11288 case GGML_OP_SUB:
11289 case GGML_OP_MUL:
11290 case GGML_OP_DIV:
11291 case GGML_OP_CONCAT:
11292 case GGML_OP_UPSCALE:
11293 case GGML_OP_SCALE:
11294 case GGML_OP_SQR:
11295 case GGML_OP_SQRT:
11296 case GGML_OP_SIN:
11297 case GGML_OP_COS:
11298 case GGML_OP_CLAMP:
11299 case GGML_OP_PAD:
11300 case GGML_OP_ROLL:
11301 case GGML_OP_CPY:
11302 case GGML_OP_SET_ROWS:
11303 case GGML_OP_CONT:
11304 case GGML_OP_DUP:
11305 case GGML_OP_SILU_BACK:
11306 case GGML_OP_NORM:
11307 case GGML_OP_GROUP_NORM:
11308 case GGML_OP_RMS_NORM:
11309 case GGML_OP_RMS_NORM_BACK:
11310 case GGML_OP_L2_NORM:
11311 case GGML_OP_DIAG_MASK_INF:
11312 case GGML_OP_SOFT_MAX:
11313 case GGML_OP_SOFT_MAX_BACK:
11314 case GGML_OP_ROPE:
11315 case GGML_OP_ROPE_BACK:
11316 case GGML_OP_MUL_MAT:
11317 case GGML_OP_MUL_MAT_ID:
11318 case GGML_OP_ARGSORT:
11319 case GGML_OP_SUM:
11320 case GGML_OP_SUM_ROWS:
11321 case GGML_OP_MEAN:
11322 case GGML_OP_ARGMAX:
11323 case GGML_OP_COUNT_EQUAL:
11324 case GGML_OP_IM2COL:
11325 case GGML_OP_IM2COL_3D:
11326 case GGML_OP_TIMESTEP_EMBEDDING:
11327 case GGML_OP_CONV_TRANSPOSE_1D:
11328 case GGML_OP_POOL_2D:
11329 case GGML_OP_CONV_2D:
11330 case GGML_OP_CONV_TRANSPOSE_2D:
11331 case GGML_OP_CONV_2D_DW:
11332 case GGML_OP_RWKV_WKV6:
11333 case GGML_OP_RWKV_WKV7:
11334 case GGML_OP_SSM_SCAN:
11335 case GGML_OP_SSM_CONV:
11336 case GGML_OP_LEAKY_RELU:
11337 case GGML_OP_FLASH_ATTN_EXT:
11338 case GGML_OP_OPT_STEP_ADAMW:
11339 case GGML_OP_OPT_STEP_SGD:
11340 break;
11341 default:
11342 std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
11343 GGML_ABORT("fatal error");
11344 }
11345
11346 vk_context compute_ctx;
11347
11348 if (ctx->compute_ctx.expired()) {
11349 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
11350 ctx->compute_ctx = compute_ctx;
11351 ggml_vk_ctx_begin(ctx->device, compute_ctx);
11352 } else {
11353 compute_ctx = ctx->compute_ctx.lock();
11354 }
11355
11356 {
11357 // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
11358 // to synchronize them. This handles most "normal" synchronization when computing the graph, and when
11359 // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
11360 // outside of this logic. When a node uses one of the prealloc buffers for something like
11361 // dequantization or split_k, additional synchronization is needed between those passes.
11362 bool need_sync = false;
11363
11364 // Check whether "node" requires synchronization. The node requires synchronization if it
11365 // overlaps in memory with another unsynchronized node and at least one of them is a write.
11366 // Destination nodes are checked against both the written/read lists. Source nodes are only
11367 // checked against the written list. Two nodes overlap in memory if they come from the same
11368 // buffer and the tensor or view ranges overlap.
11369 auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
11370 if (unsynced_nodes.size() == 0) {
11371 return false;
11372 }
11373 auto n_base = vk_tensor_offset(node) + node->view_offs;
11374 auto n_size = ggml_nbytes(node);
11375 ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
11376 vk_buffer a_buf = a_buf_ctx->dev_buffer;
11377 for (auto &other : unsynced_nodes) {
11378 ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
11379 vk_buffer o_buf = o_buf_ctx->dev_buffer;
11380 if (a_buf == o_buf) {
11381 auto o_base = vk_tensor_offset(other) + other->view_offs;
11382 auto o_size = ggml_nbytes(other);
11383
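// Byte ranges [n_base, n_base + n_size) and [o_base, o_base + o_size) overlap if either base falls inside the other range.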
11384 if ((o_base <= n_base && n_base < o_base + o_size) ||
11385 (n_base <= o_base && o_base < n_base + n_size)) {
11386 return true;
11387 }
11388 }
11389 }
11390 return false;
11391 };
11392
11393 // For all fused ops, check if the destination node or any of the source
11394 // nodes require synchronization.
11395 for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
11396 const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
11397 // If the node actually writes to memory, then check if it needs to sync
11398 if (ctx->fused_ops_write_mask & (1 << i)) {
11399 if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
11400 need_sync = true;
11401 break;
11402 }
11403 }
11404 for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
11405 if (!cur_node->src[j]) {
11406 continue;
11407 }
11408 if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
11409 need_sync = true;
11410 break;
11411 }
11412 }
11413 }
11414
11415#define ENABLE_SYNC_LOGGING 0
11416
11417 if (need_sync) {
11418#if ENABLE_SYNC_LOGGING
11419 std::cerr << "sync" << std::endl;
11420#endif
11421 ctx->unsynced_nodes_written.clear();
11422 ctx->unsynced_nodes_read.clear();
11423 ggml_vk_sync_buffers(ctx, compute_ctx);
11424 }
11425 // Add all fused nodes to the unsynchronized lists.
11426 for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
11427 const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
11428 // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
11429 if (ctx->fused_ops_write_mask & (1 << i)) {
11430 ctx->unsynced_nodes_written.push_back(cur_node);
11431 }
11432 for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
11433 if (!cur_node->src[j]) {
11434 continue;
11435 }
11436 ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
11437 }
11438 }
11439 }
11440#if ENABLE_SYNC_LOGGING
11441 for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
11442 auto *n = cgraph->nodes[node_idx + i];
11443 std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
11444 if (n->op == GGML_OP_GLU) {
11445 std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
11446 }
11447 if (n->op == GGML_OP_ROPE) {
11448 const int mode = ((const int32_t *) n->op_params)[2];
11449 std::cerr << " rope mode: " << mode;
11450 }
11451 std::cerr << std::endl;
11452 }
11453#endif
11454
11455 switch (node->op) {
11456 case GGML_OP_REPEAT:
11457 ggml_vk_repeat(ctx, compute_ctx, src0, node);
11458
11459 break;
11460 case GGML_OP_REPEAT_BACK:
11461 ggml_vk_repeat_back(ctx, compute_ctx, src0, node);
11462
11463 break;
11464 case GGML_OP_ACC:
11465 ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
11466
11467 break;
11468 case GGML_OP_GET_ROWS:
11469 ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
11470
11471 break;
11472 case GGML_OP_ADD:
11473 if (ctx->num_additional_fused_ops) {
11474 ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx);
11475 } else {
11476 ggml_vk_add(ctx, compute_ctx, src0, src1, node);
11477 }
11478 break;
11479 case GGML_OP_SUB:
11480 ggml_vk_sub(ctx, compute_ctx, src0, src1, node);
11481
11482 break;
11483 case GGML_OP_MUL:
11484 ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
11485
11486 break;
11487 case GGML_OP_DIV:
11488 ggml_vk_div(ctx, compute_ctx, src0, src1, node);
11489
11490 break;
11491 case GGML_OP_ADD_ID:
11492 ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node);
11493
11494 break;
11495 case GGML_OP_CONCAT:
11496 ggml_vk_concat(ctx, compute_ctx, src0, src1, node);
11497
11498 break;
11499 case GGML_OP_UPSCALE:
11500 ggml_vk_upscale(ctx, compute_ctx, src0, node);
11501
11502 break;
11503 case GGML_OP_SCALE:
11504 ggml_vk_scale(ctx, compute_ctx, src0, node);
11505
11506 break;
11507 case GGML_OP_SQR:
11508 ggml_vk_sqr(ctx, compute_ctx, src0, node);
11509
11510 break;
11511 case GGML_OP_SQRT:
11512 ggml_vk_sqrt(ctx, compute_ctx, src0, node);
11513
11514 break;
11515 case GGML_OP_SIN:
11516 ggml_vk_sin(ctx, compute_ctx, src0, node);
11517
11518 break;
11519 case GGML_OP_COS:
11520 ggml_vk_cos(ctx, compute_ctx, src0, node);
11521
11522 break;
11523 case GGML_OP_CLAMP:
11524 ggml_vk_clamp(ctx, compute_ctx, src0, node);
11525
11526 break;
11527 case GGML_OP_PAD:
11528 ggml_vk_pad(ctx, compute_ctx, src0, node);
11529
11530 break;
11531 case GGML_OP_ROLL:
11532 ggml_vk_roll(ctx, compute_ctx, src0, node);
11533
11534 break;
11535 case GGML_OP_CPY:
11536 case GGML_OP_CONT:
11537 case GGML_OP_DUP:
11538 ggml_vk_cpy(ctx, compute_ctx, src0, node);
11539
11540 break;
11541 case GGML_OP_SET_ROWS:
11542 ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);
11543
11544 break;
11545 case GGML_OP_SILU_BACK:
11546 ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);
11547
11548 break;
11549 case GGML_OP_NORM:
11550 ggml_vk_norm(ctx, compute_ctx, src0, node);
11551
11552 break;
11553 case GGML_OP_GROUP_NORM:
11554 ggml_vk_group_norm(ctx, compute_ctx, src0, node);
11555
11556 break;
11557 case GGML_OP_RMS_NORM:
11558 ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
11559 break;
11560 case GGML_OP_RMS_NORM_BACK:
11561 ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
11562
11563 break;
11564 case GGML_OP_L2_NORM:
11565 ggml_vk_l2_norm(ctx, compute_ctx, src0, node);
11566
11567 break;
11568 case GGML_OP_UNARY:
11569 switch (ggml_get_unary_op(node)) {
11570 case GGML_UNARY_OP_EXP:
11571 case GGML_UNARY_OP_SILU:
11572 case GGML_UNARY_OP_GELU:
11573 case GGML_UNARY_OP_GELU_ERF:
11574 case GGML_UNARY_OP_GELU_QUICK:
11575 case GGML_UNARY_OP_RELU:
11576 case GGML_UNARY_OP_TANH:
11577 case GGML_UNARY_OP_SIGMOID:
11578 case GGML_UNARY_OP_HARDSIGMOID:
11579 case GGML_UNARY_OP_HARDSWISH:
11580 ggml_vk_unary(ctx, compute_ctx, src0, node);
11581 break;
11582 default:
11583 return false;
11584 }
11585 break;
11586 case GGML_OP_GLU:
11587 switch (ggml_get_glu_op(node)) {
11588 case GGML_GLU_OP_GEGLU:
11589 case GGML_GLU_OP_REGLU:
11590 case GGML_GLU_OP_SWIGLU:
11591 case GGML_GLU_OP_SWIGLU_OAI:
11592 case GGML_GLU_OP_GEGLU_ERF:
11593 case GGML_GLU_OP_GEGLU_QUICK:
11594 ggml_vk_glu(ctx, compute_ctx, src0, src1, node);
11595 break;
11596 default:
11597 return false;
11598 }
11599 break;
11600 case GGML_OP_DIAG_MASK_INF:
11601 ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);
11602
11603 break;
11604 case GGML_OP_SOFT_MAX:
11605 if (ctx->num_additional_fused_ops) {
11606 ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
11607 } else {
11608 ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
11609 }
11610
11611 break;
11612 case GGML_OP_SOFT_MAX_BACK:
11613 ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node);
11614
11615 break;
11616 case GGML_OP_ROPE:
11617 ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false);
11618
11619 break;
11620 case GGML_OP_ROPE_BACK:
11621 ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);
11622
11623 break;
11624 case GGML_OP_ARGSORT:
11625 if (ctx->num_additional_fused_ops) {
11626 ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
11627 } else {
11628 ggml_vk_argsort(ctx, compute_ctx, src0, node);
11629 }
11630
11631 break;
11632 case GGML_OP_SUM:
11633 ggml_vk_sum(ctx, compute_ctx, src0, node);
11634
11635 break;
11636 case GGML_OP_SUM_ROWS:
11637 ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
11638
11639 break;
11640 case GGML_OP_MEAN:
11641 ggml_vk_mean(ctx, compute_ctx, src0, node);
11642
11643 break;
11644 case GGML_OP_ARGMAX:
11645 ggml_vk_argmax(ctx, compute_ctx, src0, node);
11646
11647 break;
11648 case GGML_OP_COUNT_EQUAL:
11649 ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
11650
11651 break;
11652 case GGML_OP_IM2COL:
11653 ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
11654
11655 break;
11656 case GGML_OP_IM2COL_3D:
11657 ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node);
11658
11659 break;
11660 case GGML_OP_TIMESTEP_EMBEDDING:
11661 ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
11662
11663 break;
11664 case GGML_OP_CONV_TRANSPOSE_1D:
11665 ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
11666
11667 break;
11668 case GGML_OP_POOL_2D:
11669 ggml_vk_pool_2d(ctx, compute_ctx, src0, node);
11670
11671 break;
11672 case GGML_OP_CONV_2D:
11673 ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
11674
11675 break;
11676 case GGML_OP_CONV_TRANSPOSE_2D:
11677 ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node);
11678
11679 break;
11680 case GGML_OP_CONV_2D_DW:
11681 ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
11682
11683 break;
11684 case GGML_OP_LEAKY_RELU:
11685 ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);
11686
11687 break;
11688 case GGML_OP_MUL_MAT:
11689 ggml_vk_mul_mat(ctx, compute_ctx, cgraph, node_idx);
11690
11691 break;
11692 case GGML_OP_MUL_MAT_ID:
11693 ggml_vk_mul_mat_id(ctx, compute_ctx, cgraph, node_idx);
11694
11695 break;
11696
11697 case GGML_OP_FLASH_ATTN_EXT:
11698 ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node);
11699
11700 break;
11701
11702 case GGML_OP_RWKV_WKV6:
11703 ggml_vk_rwkv_wkv6(ctx, compute_ctx, node);
11704
11705 break;
11706
11707 case GGML_OP_RWKV_WKV7:
11708 ggml_vk_rwkv_wkv7(ctx, compute_ctx, node);
11709
11710 break;
11711
11712 case GGML_OP_SSM_SCAN:
11713 ggml_vk_ssm_scan(ctx, compute_ctx, node);
11714
11715 break;
11716
11717 case GGML_OP_SSM_CONV:
11718 ggml_vk_ssm_conv(ctx, compute_ctx, node);
11719
11720 break;
11721
11722 case GGML_OP_OPT_STEP_ADAMW:
11723 ggml_vk_opt_step_adamw(ctx, compute_ctx, node);
11724
11725 break;
11726
11727 case GGML_OP_OPT_STEP_SGD:
11728 ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node);
11729
11730 break;
11731 default:
11732 return false;
11733 }
11734
11735 ctx->tensor_ctxs[node_idx] = compute_ctx;
11736
11737#if defined(GGML_VULKAN_CHECK_RESULTS)
11738 // Force context reset on each node so that each tensor ends up in its own context
11739 // and can be run and compared to its CPU equivalent separately
11740 last_node = true;
11741#endif
11742
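// Close and submit the accumulated command context, either because the caller requested an early submit or because this is the last node of the graph.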
11743 if (submit || last_node) {
11744 ggml_vk_ctx_end(compute_ctx);
11745
11746 // TODO probably it'd be better to pass an exit_node flag to ggml_vk_compute_forward
11747 if (last_node) {
11748 compute_ctx->exit_tensor_idx = node_idx_begin;
11749 }
11750 else {
11751 compute_ctx->exit_tensor_idx = -1;
11752 }
11753
11754 ctx->compute_ctx.reset();
11755
11756 bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
11757 if (!ok) {
11758 if (node->op == GGML_OP_UNARY) {
11759 std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
11760 } else if (node->op == GGML_OP_GLU) {
11761 std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
11762 } else {
11763 std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
11764 }
11765 }
11766
11767 }
11768 return true;
11769}
11770
11771static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
11772 GGML_UNUSED(cgraph);
11773 ggml_backend_buffer * buf = nullptr;
11774
11775 switch (tensor->op) {
11776 case GGML_OP_ADD:
11777 case GGML_OP_ACC:
11778 case GGML_OP_GET_ROWS:
11779 case GGML_OP_SUB:
11780 case GGML_OP_MUL:
11781 case GGML_OP_DIV:
11782 case GGML_OP_ADD_ID:
11783 case GGML_OP_CONCAT:
11784 case GGML_OP_UPSCALE:
11785 case GGML_OP_SCALE:
11786 case GGML_OP_SQR:
11787 case GGML_OP_SQRT:
11788 case GGML_OP_SIN:
11789 case GGML_OP_COS:
11790 case GGML_OP_CLAMP:
11791 case GGML_OP_PAD:
11792 case GGML_OP_ROLL:
11793 case GGML_OP_CPY:
11794 case GGML_OP_SET_ROWS:
11795 case GGML_OP_CONT:
11796 case GGML_OP_DUP:
11797 case GGML_OP_SILU_BACK:
11798 case GGML_OP_NORM:
11799 case GGML_OP_GROUP_NORM:
11800 case GGML_OP_RMS_NORM:
11801 case GGML_OP_RMS_NORM_BACK:
11802 case GGML_OP_L2_NORM:
11803 case GGML_OP_DIAG_MASK_INF:
11804 case GGML_OP_SOFT_MAX:
11805 case GGML_OP_SOFT_MAX_BACK:
11806 case GGML_OP_ROPE:
11807 case GGML_OP_ROPE_BACK:
11808 case GGML_OP_RESHAPE:
11809 case GGML_OP_VIEW:
11810 case GGML_OP_PERMUTE:
11811 case GGML_OP_TRANSPOSE:
11812 case GGML_OP_NONE:
11813 case GGML_OP_ARGSORT:
11814 case GGML_OP_SUM:
11815 case GGML_OP_SUM_ROWS:
11816 case GGML_OP_MEAN:
11817 case GGML_OP_ARGMAX:
11818 case GGML_OP_COUNT_EQUAL:
11819 case GGML_OP_IM2COL:
11820 case GGML_OP_IM2COL_3D:
11821 case GGML_OP_TIMESTEP_EMBEDDING:
11822 case GGML_OP_CONV_TRANSPOSE_1D:
11823 case GGML_OP_POOL_2D:
11824 case GGML_OP_CONV_2D:
11825 case GGML_OP_CONV_TRANSPOSE_2D:
11826 case GGML_OP_CONV_2D_DW:
11827 case GGML_OP_RWKV_WKV6:
11828 case GGML_OP_RWKV_WKV7:
11829 case GGML_OP_SSM_SCAN:
11830 case GGML_OP_SSM_CONV:
11831 case GGML_OP_LEAKY_RELU:
11832 case GGML_OP_REPEAT:
11833 case GGML_OP_REPEAT_BACK:
11834 case GGML_OP_OPT_STEP_ADAMW:
11835 case GGML_OP_OPT_STEP_SGD:
11836 buf = tensor->buffer;
11837 break;
11838 case GGML_OP_UNARY:
11839 switch (ggml_get_unary_op(tensor)) {
11840 case GGML_UNARY_OP_EXP:
11841 case GGML_UNARY_OP_SILU:
11842 case GGML_UNARY_OP_GELU:
11843 case GGML_UNARY_OP_GELU_ERF:
11844 case GGML_UNARY_OP_GELU_QUICK:
11845 case GGML_UNARY_OP_RELU:
11846 case GGML_UNARY_OP_TANH:
11847 case GGML_UNARY_OP_SIGMOID:
11848 case GGML_UNARY_OP_HARDSIGMOID:
11849 case GGML_UNARY_OP_HARDSWISH:
11850 buf = tensor->buffer;
11851 break;
11852 default:
11853 return false;
11854 }
11855 break;
11856 case GGML_OP_GLU:
11857 switch (ggml_get_glu_op(tensor)) {
11858 case GGML_GLU_OP_GEGLU:
11859 case GGML_GLU_OP_REGLU:
11860 case GGML_GLU_OP_SWIGLU:
11861 case GGML_GLU_OP_SWIGLU_OAI:
11862 case GGML_GLU_OP_GEGLU_ERF:
11863 case GGML_GLU_OP_GEGLU_QUICK:
11864 buf = tensor->buffer;
11865 break;
11866 default:
11867 return false;
11868 }
11869 break;
11870 case GGML_OP_MUL_MAT:
11871 case GGML_OP_MUL_MAT_ID:
11872 case GGML_OP_FLASH_ATTN_EXT:
11873 buf = tensor->buffer;
11874
11875 break;
11876 default:
11877 return false;
11878 }
11879
11880 if (buf == nullptr) {
11881 return false;
11882 }
11883
11884 VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
11885
11886 vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
11887
11888 // always wait for the GPU work to be done for the last submit
11889 if (tensor_idx == subctx->exit_tensor_idx) {
11890 use_fence = true;
11891 }
11892
11893 // Only run if ctx hasn't been submitted yet
11894 if (!subctx->seqs.empty()) {
11895#ifdef GGML_VULKAN_CHECK_RESULTS
11896 ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
11897 use_fence = true;
11898#endif
11899
11900 // Do staging buffer copies
11901 for (auto& cpy : subctx->in_memcpys) {
11902 memcpy(cpy.dst, cpy.src, cpy.n);
11903 }
11904
11905 for (auto& mset : subctx->memsets) {
11906 memset(mset.dst, mset.val, mset.n);
11907 }
11908
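// An "almost ready" submit signals the dedicated almost_ready fence (at most one pending) so the main fence remains available for the final submit.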
11909 if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
11910 ggml_vk_submit(subctx, ctx->almost_ready_fence);
11911 ctx->almost_ready_fence_pending = true;
11912 } else {
11913 ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
11914 }
11915
11916 if (use_fence) {
11917 ggml_vk_wait_for_fence(ctx);
11918 }
11919#ifdef GGML_VULKAN_CHECK_RESULTS
11920 ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
11921#endif
11922 }
11923
11924 if (tensor_idx == subctx->exit_tensor_idx) {
11925 // Do staging buffer copies
11926 for (auto& cpy : subctx->out_memcpys) {
11927 memcpy(cpy.dst, cpy.src, cpy.n);
11928 }
11929 subctx->in_memcpys.clear();
11930 subctx->out_memcpys.clear();
11931 subctx->memsets.clear();
11932 }
11933
11934 return true;
11935}
11936
11937// Clean up after graph processing is done
11938static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
11939 VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
11940 ctx->prealloc_y_last_pipeline_used = {};
11941
11942 ctx->unsynced_nodes_written.clear();
11943 ctx->unsynced_nodes_read.clear();
11944 ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
11945
11946 ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
11947 ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
11948
11949 for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
11950 ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
11951 }
11952 ctx->gc.semaphores.clear();
11953
11954 for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
11955 ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
11956 }
11957 ctx->gc.tl_semaphores.clear();
11958 ctx->semaphore_idx = 0;
11959
11960 ctx->event_idx = 0;
11961
11962 for (auto& event : ctx->gc.events) {
11963 ctx->device->device.resetEvent(event);
11964 }
11965
11966 ctx->tensor_ctxs.clear();
11967 ctx->gc.contexts.clear();
11968 ctx->pipeline_descriptor_set_requirements = 0;
11969 ctx->descriptor_set_idx = 0;
11970}
11971
11972// Clean up on backend free
11973static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
11974 VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
11975 ggml_vk_graph_cleanup(ctx);
11976
11977 ggml_vk_destroy_buffer(ctx->prealloc_x);
11978 ggml_vk_destroy_buffer(ctx->prealloc_y);
11979 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
11980 ctx->prealloc_y_last_pipeline_used = nullptr;
11981
11982 ctx->prealloc_size_x = 0;
11983 ctx->prealloc_size_y = 0;
11984 ctx->prealloc_size_split_k = 0;
11985
11986 for (auto& event : ctx->gc.events) {
11987 ctx->device->device.destroyEvent(event);
11988 }
11989 ctx->gc.events.clear();
11990
11991 ctx->device->device.destroyFence(ctx->fence);
11992 ctx->device->device.destroyFence(ctx->almost_ready_fence);
11993
11994 for (auto& pool : ctx->descriptor_pools) {
11995 ctx->device->device.destroyDescriptorPool(pool);
11996 }
11997 ctx->descriptor_pools.clear();
11998 ctx->descriptor_sets.clear();
11999
12000 ctx->compute_cmd_pool.destroy(ctx->device->device);
12001 ctx->transfer_cmd_pool.destroy(ctx->device->device);
12002}
12003
12004static int ggml_vk_get_device_count() {
12005 ggml_vk_instance_init();
12006
12007 return vk_instance.device_indices.size();
12008}
12009
12010static void ggml_vk_get_device_description(int device, char * description, size_t description_size) {
12011 ggml_vk_instance_init();
12012
12013 std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
12014
12015 vk::PhysicalDeviceProperties props;
12016 devices[device].getProperties(&props);
12017
12018 snprintf(description, description_size, "%s", props.deviceName.data());
12019}
12020
12021// backend interface
12022
12023#define UNUSED GGML_UNUSED
12024
12025// device backend
12026
12027static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
12028 return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
12029}
12030
12031static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
12032 VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
12033 ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12034 ggml_vk_destroy_buffer(ctx->dev_buffer);
12035 delete ctx;
12036}
12037
12038static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
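// Device buffers are not host-mapped; a fixed sentinel base pointer is returned and per-tensor offsets are reconstructed from it via vk_tensor_offset().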
12039 return vk_ptr_base;
12040
12041 UNUSED(buffer);
12042}
12043
12044static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
12045 VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
12046 if (tensor->view_src != nullptr) {
12047 GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
12048 }
12049 return GGML_STATUS_SUCCESS;
12050}
12051
12052static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
12053 VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
12054 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12055 vk_buffer buf = buf_ctx->dev_buffer;
12056
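// The underlying buffer fill operates on 32-bit words, so replicate the byte value across all four bytes.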
12057 uint32_t val32 = (uint32_t)value * 0x01010101;
12058 ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
12059}
12060
12061static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
12062 VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
12063 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12064 vk_buffer buf = buf_ctx->dev_buffer;
12065
12066 ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12067}
12068
12069static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
12070 VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
12071 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12072
12073 vk_buffer buf = buf_ctx->dev_buffer;
12074
12075 ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12076}
12077
12078static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
12079 if (ggml_backend_buffer_is_vk(src->buffer)) {
12080 ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
12081 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
12082
12083 vk_buffer src_buf = src_buf_ctx->dev_buffer;
12084 vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
12085
12086 ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
12087
12088 return true;
12089 }
12090 return false;
12091
12092 UNUSED(buffer);
12093}
12094
12095static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
12096 ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12097
12098 ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
12099}
12100
12101static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
12102 /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
12103 /* .get_base = */ ggml_backend_vk_buffer_get_base,
12104 /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
12105 /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
12106 /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
12107 /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
12108 /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
12109 /* .clear = */ ggml_backend_vk_buffer_clear,
12110 /* .reset = */ NULL,
12111};
12112
12113// vk buffer type
12114static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
12115 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
12116
12117 return ctx->name.c_str();
12118}
12119
12120static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
12121 VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")");
12122 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
12123
12124 vk_buffer dev_buffer = nullptr;
12125 try {
12126        dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
12127 } catch (const vk::SystemError& e) {
12128 return nullptr;
12129 }
12130
12131 ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);
12132
12133    return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
12134}
12135
12136static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
12137 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
12138 return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
12139}
12140
12141static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
12142 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
12143 return ctx->device->suballocation_block_size;
12144}
12145
12146static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
12147 return ggml_nbytes(tensor);
12148
12149 UNUSED(buft);
12150}
12151
12152ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
12153 ggml_vk_instance_init();
12154
12155 VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");
12156
12157    vk_device dev = ggml_vk_get_device(dev_num);
12158
12159 return &dev->buffer_type;
12160}
12161
12162// host buffer type
12163
12164static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
12165 return GGML_VK_NAME "_Host";
12166
12167 UNUSED(buft);
12168}
12169
12170static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
12171 return GGML_VK_NAME "_Host";
12172
12173 UNUSED(buffer);
12174}
12175
12176static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
12177 VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
12178    ggml_vk_host_free(vk_instance.devices[0], buffer->context);
12179}
12180
12181static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
12182 VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");
12183
12184 size += 32; // Behave like the CPU buffer type
12185 void * ptr = nullptr;
12186 try {
12187        ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
12188 } catch (vk::SystemError& e) {
12189 GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
12190 // fallback to cpu buffer
12191        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
12192 }
12193
12194 ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
12195 buffer->buft = buft;
12196 buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
12197
12198 return buffer;
12199
12200 UNUSED(buft);
12201}
12202
12203static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
12204 return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;
12205
12206 UNUSED(buft);
12207}
12208
12209static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
12210 return vk_instance.devices[0]->suballocation_block_size;
12211
12212 UNUSED(buft);
12213}
12214
12215// Should be changed to return device-specific host buffer type
12216// but that probably requires changes in llama.cpp
12217ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
12218 static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
12219 /* .iface = */ {
12220 /* .get_name = */ ggml_backend_vk_host_buffer_type_name,
12221 /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
12222 /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
12223 /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size,
12224 /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
12225 /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
12226 },
12227        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
12228 /* .context = */ nullptr,
12229 };
12230
12231 // Make sure device 0 is initialized
12232 ggml_vk_instance_init();
12233    ggml_vk_get_device(0);
12234
12235 return &ggml_backend_vk_buffer_type_host;
12236}
12237
12238
12239// backend
12240
12241static const char * ggml_backend_vk_name(ggml_backend_t backend) {
12242 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12243
12244 return ctx->name.c_str();
12245}
12246
12247static void ggml_backend_vk_free(ggml_backend_t backend) {
12248 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12249 VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");
12250
12251 ggml_vk_cleanup(ctx);
12252
12253 delete ctx;
12254 delete backend;
12255}
12256
12257static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
12258 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12259
12260 return &ctx->device->buffer_type;
12261}
12262
12263static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
12264 VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
12265 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12266 GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
12267
12268 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
12269
12270 vk_context transfer_ctx;
12271
12272 if (ctx->transfer_ctx.expired()) {
12273 // Initialize new transfer context
12274        transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
12275        ctx->transfer_ctx = transfer_ctx;
12276        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
12277 } else {
12278 transfer_ctx = ctx->transfer_ctx.lock();
12279 }
12280
12281 vk_buffer buf = buf_ctx->dev_buffer;
12282
12283    ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12284}
12285
12286static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
12287 VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
12288 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12289 GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
12290
12291 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
12292
12293 vk_context transfer_ctx;
12294
12295 if (ctx->transfer_ctx.expired()) {
12296 // Initialize new transfer context
12297        transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
12298        ctx->transfer_ctx = transfer_ctx;
12299        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
12300 } else {
12301 transfer_ctx = ctx->transfer_ctx.lock();
12302 }
12303
12304 vk_buffer buf = buf_ctx->dev_buffer;
12305
12306    ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12307}
12308
12309static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
12310 VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
12311 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12312    if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
12313 ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
12314 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
12315
12316 vk_context transfer_ctx;
12317
12318 if (ctx->transfer_ctx.expired()) {
12319 // Initialize new transfer context
12320            transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
12321            ctx->transfer_ctx = transfer_ctx;
12322            ggml_vk_ctx_begin(ctx->device, transfer_ctx);
12323 } else {
12324 transfer_ctx = ctx->transfer_ctx.lock();
12325 }
12326
12327 vk_buffer src_buf = src_buf_ctx->dev_buffer;
12328 vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
12329
12330        ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
12331 return true;
12332 }
12333
12334 return false;
12335}
12336
12337static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
12338 VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
12339 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12340 if(ctx->transfer_ctx.expired()) {
12341 return;
12342 }
12343
12344 vk_context transfer_ctx = ctx->transfer_ctx.lock();
12345
12346    ggml_vk_ctx_end(transfer_ctx);
12347
12348 for (auto& cpy : transfer_ctx->in_memcpys) {
12349        memcpy(cpy.dst, cpy.src, cpy.n);
12350 }
12351
12352    ggml_vk_submit(transfer_ctx, ctx->fence);
12353 ggml_vk_wait_for_fence(ctx);
12354
12355 for (auto& cpy : transfer_ctx->out_memcpys) {
12356        memcpy(cpy.dst, cpy.src, cpy.n);
12357 }
12358
12359 ctx->transfer_ctx.reset();
12360}
12361
12362static bool ggml_vk_is_empty(ggml_tensor * node) {
12363    return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
12364}
12365
12366static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
12367 if (!ggml_can_fuse(cgraph, node_idx, ops)) {
12368 return false;
12369 }
12370
12371 if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
12372 // additional constraints specific to this fusion
12373 const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
12374 const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
12375
12376 GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
12377 GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
12378 // rms_norm only supports f32
12379 if (mul->src[0]->type != GGML_TYPE_F32 ||
12380 mul->src[1]->type != GGML_TYPE_F32 ||
12381 mul->type != GGML_TYPE_F32) {
12382 return false;
12383 }
12384 // if rms_norm is the B operand, then we don't handle broadcast
12385 if (rms_norm == mul->src[1] &&
12386            !ggml_are_same_shape(mul->src[0], rms_norm)) {
12387 return false;
12388 }
12389 // rms_norm shader assumes contiguous rows
12390        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
12391 return false;
12392 }
12393 }
12394 if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
12395 // additional constraints specific to this fusion
12396 const ggml_tensor *mul = cgraph->nodes[node_idx];
12397 const ggml_tensor *add = cgraph->nodes[node_idx + 1];
12398 const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];
12399
12400 // mat-vec only
12401        if (ggml_nrows(mul) != 1) {
12402 return false;
12403 }
12404 // shaders assume the types match
12405 if (mul->type != bias->type) {
12406 return false;
12407 }
12408 // shaders reuse the D shape for bias
12409        if (!ggml_are_same_shape(mul, bias) ||
12410            !ggml_are_same_stride(mul, bias)) {
12411            return false;
12412        }
12413        // unaligned bias isn't handled
12414        if (get_misalign_bytes(ctx, bias) != 0) {
12415 return false;
12416 }
12417 }
12418 if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
12419 // additional constraints specific to this fusion
12420 const ggml_tensor *mul = cgraph->nodes[node_idx];
12421 const ggml_tensor *add = cgraph->nodes[node_idx + 1];
12422 const ggml_tensor *bias = add->src[1];
12423
12424 if (mul != add->src[0]) {
12425 return false;
12426 }
12427 // mat-vec only
12428 if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
12429 return false;
12430 }
12431 // shaders assume the types match
12432 if (mul->type != bias->type) {
12433 return false;
12434 }
12435 // shaders assume the bias is contiguous
12436        if (!ggml_is_contiguous(bias)) {
12437 return false;
12438 }
12439 // the ID tensor must be the same for mul_mat_id and add_id
12440 if (mul->src[2] != add->src[2]) {
12441 return false;
12442 }
12443 // unaligned bias isn't handled
12444        if (get_misalign_bytes(ctx, bias) != 0) {
12445 return false;
12446 }
12447 }
12448
12449 return true;
12450}
12451
12452static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
12453 int node_idx, topk_moe_mode mode) {
12454
12455 const ggml_tensor * softmax;
12456 const ggml_tensor * weights;
12457
12458 switch (mode) {
12459 case TOPK_MOE_EARLY_SOFTMAX_NORM:
12460 softmax = cgraph->nodes[node_idx + 0];
12461 weights = cgraph->nodes[node_idx + 9];
12462 break;
12463 case TOPK_MOE_EARLY_SOFTMAX:
12464 softmax = cgraph->nodes[node_idx + 0];
12465 weights = cgraph->nodes[node_idx + 4];
12466 break;
12467 case TOPK_MOE_LATE_SOFTMAX:
12468 softmax = cgraph->nodes[node_idx + 4];
12469 weights = cgraph->nodes[node_idx + 5];
12470 break;
12471 default:
12472 return false;
12473 }
12474
12475 const float * op_params = (const float *)softmax->op_params;
12476
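    // soft_max stores its scale in op_params[0] and the ALiBi max_bias in op_params[1]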
12477 float scale = op_params[0];
12478 float max_bias = op_params[1];
12479
12480    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
12481 return false;
12482 }
12483
12484 if (scale != 1.0f || max_bias != 0.0f) {
12485 return false;
12486 }
12487
12488 // don't fuse when masks or sinks are present
12489 if (softmax->src[1] || softmax->src[2]) {
12490 return false;
12491 }
12492
12493 const int n_expert = softmax->ne[0];
12494 // n_expert must be a power of 2
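    // and must not exceed 2^(num_topk_moe_pipelines - 1), presumably the largest size the dedicated topk_moe pipelines cover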
12495    if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
12496 return false;
12497 }
12498
12499 if (!ctx->device->subgroup_arithmetic ||
12500 !ctx->device->subgroup_shuffle ||
12501 !ctx->device->subgroup_require_full_support ||
12502 ctx->device->disable_fusion) {
12503 return false;
12504 }
12505
12506 return true;
12507}
12508
12509static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
12510 int node_idx) {
12511 GGML_UNUSED(ctx);
12512 const ggml_tensor *rope = cgraph->nodes[node_idx + 0];
12513 const ggml_tensor *view = cgraph->nodes[node_idx + 1];
12514 const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];
12515
12516 // ne3 not tested
12517 if (rope->src[0]->ne[3] != 1) {
12518 return false;
12519 }
12520
12521 if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
12522 return false;
12523 }
12524
12525 if (set_rows->src[1]->type != GGML_TYPE_I64) {
12526 return false;
12527 }
12528
12529 // The view should flatten two dims of rope into one dim
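    // e.g. a rope result of shape [head_dim, n_head, n_tokens] viewed as [head_dim * n_head, n_tokens],
    // so each row corresponds to one cache slot (typical KV-cache layout)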
12530    if (!ggml_is_contiguous(view) ||
12531 view->ne[0] != rope->ne[0] * rope->ne[1]) {
12532 return false;
12533 }
12534
12535 // Only norm/neox shaders have the fusion code
12536 const int mode = ((const int32_t *) rope->op_params)[2];
12537 if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
12538 return false;
12539 }
12540
12541 return true;
12542}
12543
12544// Check whether the tensors overlap in memory but are not equal.
12545// Fusions can potentially overwrite src tensors in ways that are not prevented
12546// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
12547// to overlap if they are exactly equal.
12548// XXX TODO this check is probably missing from several fusion optimizations.
12549static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
12550 ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
12551 vk_buffer a_buf = a_buf_ctx->dev_buffer;
12552 ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
12553 vk_buffer b_buf = b_buf_ctx->dev_buffer;
12554 if (a_buf == b_buf) {
12555        auto a_base = vk_tensor_offset(a) + a->view_offs;
12556        auto a_size = ggml_nbytes(a);
12557        auto b_base = vk_tensor_offset(b) + b->view_offs;
12558        auto b_size = ggml_nbytes(b);
12559
12560 if (a_base == b_base && a_size == b_size) {
12561 return false;
12562 }
12563
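        // half-open interval overlap test: [a_base, a_base + a_size) vs [b_base, b_base + b_size)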
12564 if ((b_base <= a_base && a_base < b_base + b_size) ||
12565 (a_base <= b_base && b_base < a_base + a_size)) {
12566 return true;
12567 }
12568 }
12569 return false;
12570}
12571
12572static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
12573 int node_idx) {
12574 GGML_UNUSED(ctx);
12575 const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
12576 const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
12577 const ggml_tensor *rope = cgraph->nodes[node_idx + 2];
12578
12579 const int mode = ((const int32_t *) rope->op_params)[2];
12580
12581 // noncontig tensors aren't tested, and don't seem common in practice
12582    if (!ggml_is_contiguous(rms) ||
12583        !ggml_is_contiguous(mul) ||
12584        !ggml_is_contiguous(rope)) {
12585 return false;
12586 }
12587
12588 // only norm/neox are handled in the shader
12589 if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
12590 return false;
12591 }
12592
12593 // shared memory size for passing data from mul->rope
12594 if (mul->ne[0] > 1024) {
12595 return false;
12596 }
12597
12598 // must not overwrite srcs in a way that's not elementwise
12599 ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
12600    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
12601        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
12602 return false;
12603 }
12604
12605 return true;
12606}
12607
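// Returns the number of consecutive GGML_OP_ADD nodes starting at node_idx that can be
// fused into a single multi-input add dispatch, or 0 if no fusion applies.
// For example, the chain t1 = a + b; t2 = t1 + c; t3 = t2 + d is three ADD nodes that
// can be handled by one fused add, assuming the shape/type/alignment checks below pass.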
12608static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
12609
12610 const ggml_tensor *first_node = cgraph->nodes[node_idx];
12611 if (first_node->op != GGML_OP_ADD) {
12612 return 0;
12613 }
12614
12615 if (!ctx->device->multi_add) {
12616 return 0;
12617 }
12618
12619 int32_t num_adds = 1;
12620 while (node_idx + num_adds < cgraph->n_nodes &&
12621 cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
12622 num_adds < MAX_FUSED_ADDS) {
12623 num_adds++;
12624 }
12625
12626 // The shader currently requires same shapes (but different strides are allowed),
12627 // everything f32, and no misalignment
12628 for (int32_t i = 0; i < num_adds; ++i) {
12629 const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
12630        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
12631            !ggml_are_same_shape(first_node, next_node->src[1]) ||
12632            next_node->type != GGML_TYPE_F32 ||
12633            next_node->src[0]->type != GGML_TYPE_F32 ||
12634            next_node->src[1]->type != GGML_TYPE_F32 ||
12635            get_misalign_bytes(ctx, next_node) ||
12636            get_misalign_bytes(ctx, next_node->src[0]) ||
12637            get_misalign_bytes(ctx, next_node->src[1])) {
12638 num_adds = i;
12639 }
12640 }
12641
12642 // Verify we can fuse these
12643 ggml_op adds[MAX_FUSED_ADDS];
12644 for (int32_t i = 0; i < num_adds; ++i) {
12645 adds[i] = GGML_OP_ADD;
12646 }
12647
12648 // decrease num_adds if they can't all be fused
12649    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
12650 num_adds--;
12651 }
12652
12653 // a single add is not "fused", so just return zero
12654 if (num_adds == 1) {
12655 return 0;
12656 }
12657 return num_adds;
12658}
12659
12660static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
12661 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
12662 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12663
12664 if (vk_instance.debug_utils_support) {
12665 vk::DebugUtilsLabelEXT dul = {};
12666 dul.pLabelName = "ggml_backend_vk_graph_compute";
12667 dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
12668 vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
12669 }
12670
12671 ctx->prealloc_size_add_rms_partials_offset = 0;
12672 ctx->do_add_rms_partials = false;
12673 ctx->do_add_rms_partials_offset_calculation = false;
12674
12675 int last_node = cgraph->n_nodes - 1;
12676
12677 // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
12678    while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
12679 last_node -= 1;
12680 }
12681
12682 // Reserve tensor context space for all nodes
12683    ctx->tensor_ctxs.resize(cgraph->n_nodes);
12684
12685 bool first_node_in_batch = true; // true if next node will be first node in a batch
12686 int submit_node_idx = 0; // index to first node in a batch
12687
12688 vk_context compute_ctx;
12689 if (vk_perf_logger_enabled) {
12690 // allocate/resize the query pool
12691 if (ctx->device->num_queries < cgraph->n_nodes + 1) {
12692 if (ctx->device->query_pool) {
12693                ctx->device->device.destroyQueryPool(ctx->device->query_pool);
12694 }
12695 vk::QueryPoolCreateInfo query_create_info;
12696 query_create_info.queryType = vk::QueryType::eTimestamp;
12697 query_create_info.queryCount = cgraph->n_nodes + 100;
12698            ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
12699 ctx->device->num_queries = query_create_info.queryCount;
12700 }
12701
12702        ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1);
12703
12704 GGML_ASSERT(ctx->compute_ctx.expired());
12705        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
12706        ctx->compute_ctx = compute_ctx;
12707        ggml_vk_ctx_begin(ctx->device, compute_ctx);
12708        compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
12709 }
12710
12711 ctx->prealloc_y_last_pipeline_used = nullptr;
12712 ctx->prealloc_y_last_tensor_used = nullptr;
12713
12714 if (ctx->prealloc_size_add_rms_partials) {
12715        ggml_vk_preallocate_buffers(ctx, nullptr);
12716        if (ctx->compute_ctx.expired()) {
12717            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
12718            ctx->compute_ctx = compute_ctx;
12719            ggml_vk_ctx_begin(ctx->device, compute_ctx);
12720        } else {
12721            compute_ctx = ctx->compute_ctx.lock();
12722        }
12723        // initialize partial sums to zero.
12724        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
12725        ggml_vk_sync_buffers(ctx, compute_ctx);
12726 }
12727
12728 // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
12729 // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
12730 // (and scaled down based on model size, so smaller models submit earlier).
12731 // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
12732 int nodes_per_submit = 100;
12733 int submitted_nodes = 0;
12734 int submit_count = 0;
12735 uint64_t mul_mat_bytes = 0;
12736 uint64_t total_mul_mat_bytes = 0;
12737    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
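    // Illustration: if the previous graph read ~2 GB of matmul weights, the initial threshold is
    // min(100 MB, 2 GB / 40) = 50 MB; it is then doubled after each of the first few submits (see below).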
12738 for (int i = 0; i < cgraph->n_nodes; i++) {
12739 if (first_node_in_batch) {
12740 submit_node_idx = i;
12741 }
12742
12743 if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
12744            auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
12745 mul_mat_bytes += bytes;
12746 total_mul_mat_bytes += bytes;
12747 }
12748
12749 if (!ctx->device->disable_fusion) {
12750            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
12751            if (num_adds) {
12752                ctx->num_additional_fused_ops = num_adds - 1;
12753            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
12754                ctx->num_additional_fused_ops = 1;
12755            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
12756                ctx->num_additional_fused_ops = 1;
12757            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
12758                       ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
12759                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
12760                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
12761                ctx->num_additional_fused_ops = 4;
12762            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE }) &&
12763                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
12764                ctx->num_additional_fused_ops = 2;
12765            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
12766                ctx->num_additional_fused_ops = 1;
12767            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
12768                       ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
12769                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
12770                ctx->num_additional_fused_ops = 2;
12771            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
12772                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
12773                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
12774                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
12775                // view of argsort writes to memory
12776                ctx->fused_ops_write_mask |= 1 << 3;
12777            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
12778                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
12779                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
12780                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
12781                // view of argsort writes to memory
12782                ctx->fused_ops_write_mask |= 1 << 3;
12783            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
12784                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
12785                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
12786 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
12787 // view of argsort writes to memory
12788 ctx->fused_ops_write_mask |= 1 << 1;
12789 }
12790 }
12791 ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
12792
12793 // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
12794 bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
12795 bool submit = (submitted_nodes >= nodes_per_submit) ||
12796 (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
12797 (i + ctx->num_additional_fused_ops >= last_node) ||
12798 (almost_ready && !ctx->almost_ready_fence_pending);
12799
12800        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
12801
12802 if (vk_perf_logger_enabled) {
12803 if (ctx->compute_ctx.expired()) {
12804                compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
12805                ctx->compute_ctx = compute_ctx;
12806                ggml_vk_ctx_begin(ctx->device, compute_ctx);
12807 } else {
12808 compute_ctx = ctx->compute_ctx.lock();
12809 }
12810 // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
12811 for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
12812                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
12813 }
12814 }
12815
12816 if (enqueued) {
12817 ++submitted_nodes;
12818
12819#ifndef GGML_VULKAN_CHECK_RESULTS
12820 if (first_node_in_batch) {
12821 first_node_in_batch = false;
12822 }
12823#endif
12824 }
12825
12826 if (submit && enqueued) {
12827 first_node_in_batch = true;
12828 submitted_nodes = 0;
12829 mul_mat_bytes = 0;
12830 if (submit_count < 3) {
12831 mul_mat_bytes_per_submit *= 2;
12832 }
12833 submit_count++;
12834 }
12835 i += ctx->num_additional_fused_ops;
12836 ctx->num_additional_fused_ops = 0;
12837 ctx->fused_ops_write_mask = 0;
12838 }
12839
12840    ctx->prealloc_size_add_rms_partials = std::max(ctx->prealloc_size_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
12841 ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
12842
12843 if (vk_perf_logger_enabled) {
12844 // End the command buffer and submit/wait
12845 GGML_ASSERT(!ctx->compute_ctx.expired());
12846 compute_ctx = ctx->compute_ctx.lock();
12847        ggml_vk_ctx_end(compute_ctx);
12848
12849        ggml_vk_submit(compute_ctx, ctx->device->fence);
12850 VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
12851        ctx->device->device.resetFences({ ctx->device->fence });
12852
12853 // Get the results and pass them to the logger
12854 std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
12855 VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
12856 for (int i = 0; i < cgraph->n_nodes; i++) {
12857            if (!ggml_vk_is_empty(cgraph->nodes[i])) {
12858                ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
12859 }
12860 }
12861
12862 ctx->device->perf_logger->print_timings();
12863 }
12864
12865 ggml_vk_graph_cleanup(ctx);
12866
12867 return GGML_STATUS_SUCCESS;
12868
12869 UNUSED(backend);
12870}
12871
12872// Sort the graph for improved parallelism.
12873static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)
12874{
12875 VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)");
12876 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
12877
12878 if (ctx->device->disable_graph_optimize) {
12879 return;
12880 }
12881
12882 auto const &is_empty = [](ggml_tensor * node) -> bool {
12883 return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
12884 };
12885
12886 auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
12887 for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
12888 if (dst->src[s] == src) {
12889 return true;
12890 }
12891 }
12892 // implicit dependency if they view the same tensor
12893 const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
12894 const ggml_tensor *src2 = src->view_src ? src->view_src : src;
12895 if (dst2 == src2) {
12896 return true;
12897 }
12898 return false;
12899 };
12900
12901 // This function tries to reorder the graph to allow nodes to run in parallel.
12902    // This helps with small batches, but for large batches it's a slowdown, probably
12903 // due to cache contention. So only reorder if the majority of nodes have few rows.
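    // For example, the Q/K/V projection matmuls of one attention block don't depend on each
    // other, so placing them back to back lets them be recorded without intervening
    // synchronization (illustrative; the actual benefit depends on how barriers are inserted later).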
12904 int num_small_nodes = 0;
12905 int num_counted_nodes = 0;
12906 for (int i = 0; i < graph->n_nodes; ++i) {
12907 if (!is_empty(graph->nodes[i]) &&
12908 graph->nodes[i]->op != GGML_OP_SET_ROWS) {
12909            if (ggml_nrows(graph->nodes[i]) <= 8) {
12910 num_small_nodes++;
12911 }
12912 num_counted_nodes++;
12913 }
12914 }
12915 if (num_small_nodes < num_counted_nodes / 2) {
12916 return;
12917 }
12918
12919 std::vector<ggml_tensor *> new_order;
12920 std::vector<bool> used(graph->n_nodes, false);
12921 int first_unused = 0;
12922 while (first_unused < graph->n_nodes) {
12923 std::vector<int> current_set;
12924
12925 // Check for fusion patterns and avoid reordering them
12926 auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
12927 if (start + (int)pattern.size() <= graph->n_nodes) {
12928 bool is_pattern = true;
12929 for (size_t j = 0; j < pattern.size(); ++j) {
12930 if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
12931 is_pattern = false;
12932 }
12933 }
12934 return is_pattern;
12935 }
12936 return false;
12937 };
12938
12939 auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
12940 if (match_pattern(pattern, first_unused)) {
12941 for (size_t j = 0; j < pattern.size(); ++j) {
12942                new_order.push_back(graph->nodes[first_unused + j]);
12943 used[first_unused + j] = true;
12944 }
12945 while (first_unused < graph->n_nodes && used[first_unused]) {
12946 first_unused++;
12947 }
12948 return true;
12949 }
12950 return false;
12951 };
12952
12953 if (keep_pattern(topk_moe_early_softmax_norm)) {
12954 continue;
12955 }
12956 if (keep_pattern(topk_moe_early_softmax)) {
12957 continue;
12958 }
12959 if (keep_pattern(topk_moe_late_softmax)) {
12960 continue;
12961 }
12962
12963 // First, grab the next unused node.
12964        current_set.push_back(first_unused);
12965
12966 // Loop through the next N nodes. Grab any that don't depend on other nodes that
12967 // haven't already been run. Nodes that have already been run have used[i] set
12968 // to true. Allow nodes that depend on the previous node if it's a fusion pattern
12969 // that we support (e.g. RMS_NORM + MUL).
12970    // This first pass only grabs "real" (non-view) nodes. The second pass grabs view nodes.
12971 // The goal is to not interleave real and view nodes in a way that breaks fusion.
12972 const int NUM_TO_CHECK = 20;
12973        for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
12974 if (used[j]) {
12975 continue;
12976 }
12977 if (is_empty(graph->nodes[j])) {
12978 continue;
12979 }
12980 // Don't pull forward nodes from fusion patterns
12981 if (match_pattern(topk_moe_early_softmax_norm, j) ||
12982 match_pattern(topk_moe_early_softmax, j) ||
12983 match_pattern(topk_moe_late_softmax, j)) {
12984 continue;
12985 }
12986 bool ok = true;
12987 for (int c = first_unused; c < j; ++c) {
12988 if (!used[c] &&
12989 is_src_of(graph->nodes[j], graph->nodes[c]) &&
12990 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
12991 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
12992 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
12993 ok = false;
12994 break;
12995 }
12996 }
12997 if (ok) {
12998                current_set.push_back(j);
12999
13000 int rope_idx = j;
13001
13002 // When we've found RMS_NORM + MUL, try to find a ROPE that uses it
13003 if (j > 0 &&
13004 graph->nodes[j]->op == GGML_OP_MUL &&
13005 graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {
13006                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
13007 if (graph->nodes[k]->op == GGML_OP_ROPE &&
13008 graph->nodes[k]->src[0] == graph->nodes[j] &&
13009 // Check that other srcs are already valid
13010 graph->nodes[k]->src[1]->op == GGML_OP_NONE &&
13011 (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {
13012 rope_idx = k;
13013                            current_set.push_back(rope_idx);
13014 used[rope_idx] = true;
13015 break;
13016 }
13017 }
13018 }
13019 // Look for ROPE + VIEW + SET_ROWS and make them consecutive
13020 if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {
13021 int view_idx = -1;
13022 int set_rows_idx = -1;
13023                    for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {
13024 if (view_idx == -1 &&
13025 graph->nodes[k]->op == GGML_OP_VIEW &&
13026 graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {
13027 view_idx = k;
13028 continue;
13029 }
13030 if (view_idx != -1 &&
13031 set_rows_idx == -1 &&
13032 graph->nodes[k]->op == GGML_OP_SET_ROWS &&
13033 graph->nodes[k]->src[0] == graph->nodes[view_idx]) {
13034 set_rows_idx = k;
13035 break;
13036 }
13037 }
13038 if (set_rows_idx != -1) {
13039                        current_set.push_back(view_idx);
13040                        current_set.push_back(set_rows_idx);
13041 used[view_idx] = true;
13042 used[set_rows_idx] = true;
13043 }
13044 }
13045 }
13046 }
13047 // Second pass grabs view nodes.
13048 // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
13049 if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
13050            for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
13051 if (used[j]) {
13052 continue;
13053 }
13054 if (!is_empty(graph->nodes[j])) {
13055 continue;
13056 }
13057 bool ok = true;
13058 for (int c = first_unused; c < j; ++c) {
13059                    bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
13060 // skip views whose srcs haven't been processed.
13061 if (!used[c] &&
13062 is_src_of(graph->nodes[j], graph->nodes[c]) &&
13063 !c_in_current_set) {
13064 ok = false;
13065 break;
13066 }
13067 }
13068 if (ok) {
13069                    current_set.push_back(j);
13070 }
13071 }
13072 }
13073
13074 // Push the current set into new_order
13075 for (auto c : current_set) {
13076        new_order.push_back(graph->nodes[c]);
13077 used[c] = true;
13078 }
13079 while (first_unused < graph->n_nodes && used[first_unused]) {
13080 first_unused++;
13081 }
13082 }
13083 // Replace the graph with the new order.
13084 for (int i = 0; i < graph->n_nodes; ++i) {
13085 graph->nodes[i] = new_order[i];
13086 }
13087}
13088
13089// TODO: enable async and synchronize
13090static ggml_backend_i ggml_backend_vk_interface = {
13091 /* .get_name = */ ggml_backend_vk_name,
13092 /* .free = */ ggml_backend_vk_free,
13093 /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
13094 /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
13095 /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
13096 /* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
13097 /* .graph_plan_create = */ NULL,
13098 /* .graph_plan_free = */ NULL,
13099 /* .graph_plan_update = */ NULL,
13100 /* .graph_plan_compute = */ NULL,
13101 /* .graph_compute = */ ggml_backend_vk_graph_compute,
13102 /* .event_record = */ NULL,
13103 /* .event_wait = */ NULL,
13104 /* .graph_optimize = */ ggml_vk_graph_optimize,
13105};
13106
13107static ggml_guid_t ggml_backend_vk_guid() {
13108 static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
13109 return &guid;
13110}
13111
13112ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
13113 VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
13114
13115 ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
13116    ggml_vk_init(ctx, dev_num);
13117
13118 ggml_backend_t vk_backend = new ggml_backend {
13119 /* .guid = */ ggml_backend_vk_guid(),
13120 /* .iface = */ ggml_backend_vk_interface,
13121        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
13122 /* .context = */ ctx,
13123 };
13124
13125 return vk_backend;
13126}
13127
13128bool ggml_backend_is_vk(ggml_backend_t backend) {
13129    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
13130}
13131
13132int ggml_backend_vk_get_device_count() {
13133 return ggml_vk_get_device_count();
13134}
13135
13136void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
13137 GGML_ASSERT(device < (int) vk_instance.device_indices.size());
13138 int dev_idx = vk_instance.device_indices[device];
13139    ggml_vk_get_device_description(dev_idx, description, description_size);
13140}
13141
13142void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
13143 GGML_ASSERT(device < (int) vk_instance.device_indices.size());
13144 GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
13145
13146 vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
13147 vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
13148 vk::PhysicalDeviceMemoryProperties2 memprops = {};
13149 bool membudget_supported = vk_instance.device_supports_membudget[device];
13150
13151 if (membudget_supported) {
13152 memprops.pNext = &budgetprops;
13153 }
13154    vkdev.getMemoryProperties2(&memprops);
13155
13156 for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
13157 const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
13158
13159 if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
13160 *total = heap.size;
13161
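            // With VK_EXT_memory_budget, heapBudget is the driver's estimate of how much of this
            // heap the process may use, so budget minus current usage approximates the free amount.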
13162 if (membudget_supported && i < budgetprops.heapUsage.size()) {
13163 *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
13164 } else {
13165 *free = heap.size;
13166 }
13167 break;
13168 }
13169 }
13170}
13171
13172static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) {
13173 GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
13174
13175 vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
13176
13177 vk::PhysicalDeviceProperties2 props = {};
13178    device.getProperties2(&props);
13179
13180 return props.properties.deviceType;
13181}
13182
13183static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
13184 GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
13185
13186 vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
13187
13188 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
13189
13190 bool ext_support = false;
13191
13192 for (const auto& properties : ext_props) {
13193        if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) {
13194 ext_support = true;
13195 break;
13196 }
13197 }
13198
13199 if (!ext_support) {
13200 return "";
13201 }
13202
13203 vk::PhysicalDeviceProperties2 props = {};
13204 vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};
13205
13206 props.pNext = &pci_bus_info;
13207
13208    device.getProperties2(&props);
13209
13210 const uint32_t pci_domain = pci_bus_info.pciDomain;
13211 const uint32_t pci_bus = pci_bus_info.pciBus;
13212 const uint32_t pci_device = pci_bus_info.pciDevice;
13213 const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
13214
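    // Standard PCI identifier in domain:bus:device.function form, e.g. "0000:03:00.0"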
13215 char pci_bus_id[16] = {};
13216    snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
13217
13218 return std::string(pci_bus_id);
13219}
13220
13221//////////////////////////
13222
13223struct ggml_backend_vk_device_context {
13224 size_t device;
13225 std::string name;
13226 std::string description;
13227 bool is_integrated_gpu;
13228 std::string pci_bus_id;
13229};
13230
13231static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
13232 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13233 return ctx->name.c_str();
13234}
13235
13236static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
13237 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13238 return ctx->description.c_str();
13239}
13240
13241static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
13242 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
13243    ggml_backend_vk_get_device_memory(ctx->device, free, total);
13244}
13245
13246static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
13247 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13248    return ggml_backend_vk_buffer_type(ctx->device);
13249}
13250
13251static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
13252 UNUSED(dev);
13253 return ggml_backend_vk_host_buffer_type();
13254}
13255
13256static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
13257 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13258
13259 return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU;
13260}
13261
13262static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
13263 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13264
13265 props->name = ggml_backend_vk_device_get_name(dev);
13266 props->description = ggml_backend_vk_device_get_description(dev);
13267 props->type = ggml_backend_vk_device_get_type(dev);
13268 props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
13269    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
13270 props->caps = {
13271 /* .async = */ false,
13272 /* .host_buffer = */ true,
13273 /* .buffer_from_host_ptr = */ false,
13274 /* .events = */ false,
13275 };
13276}
13277
13278static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
13279 UNUSED(params);
13280 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13281    return ggml_backend_vk_init(ctx->device);
13282}
13283
13284static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
13285 switch (op->op) {
13286 case GGML_OP_UNARY:
13287            switch (ggml_get_unary_op(op)) {
13288 case GGML_UNARY_OP_EXP:
13289 case GGML_UNARY_OP_GELU:
13290 case GGML_UNARY_OP_GELU_ERF:
13291 case GGML_UNARY_OP_GELU_QUICK:
13292 case GGML_UNARY_OP_SILU:
13293 case GGML_UNARY_OP_RELU:
13294 case GGML_UNARY_OP_TANH:
13295 case GGML_UNARY_OP_SIGMOID:
13296 case GGML_UNARY_OP_HARDSIGMOID:
13297 case GGML_UNARY_OP_HARDSWISH:
13298                    return ggml_is_contiguous(op->src[0]) &&
13299 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
13300 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
13301 (op->src[0]->type == op->type);
13302 default:
13303 return false;
13304 }
13305 case GGML_OP_GLU:
13306            switch (ggml_get_glu_op(op)) {
13307 case GGML_GLU_OP_GEGLU:
13308 case GGML_GLU_OP_REGLU:
13309 case GGML_GLU_OP_SWIGLU:
13310 case GGML_GLU_OP_SWIGLU_OAI:
13311 case GGML_GLU_OP_GEGLU_ERF:
13312 case GGML_GLU_OP_GEGLU_QUICK:
13313                    return ggml_is_contiguous(op->src[0]) &&
13314 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
13315 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
13316 (op->src[0]->type == op->type);
13317 default:
13318 return false;
13319 }
13320 case GGML_OP_MUL_MAT:
13321 case GGML_OP_MUL_MAT_ID:
13322 {
13323 ggml_type src0_type = op->src[0]->type;
13324 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13325                const vk_device& device = ggml_vk_get_device(ctx->device);
13326 if (op->op == GGML_OP_MUL_MAT_ID) {
13327 if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
13328                    // If there's not enough shared memory for row_ids and the result tile, fall back to the CPU
13329 return false;
13330 }
13331 }
13332 switch (src0_type) {
13333 case GGML_TYPE_F32:
13334 case GGML_TYPE_F16:
13335 case GGML_TYPE_BF16:
13336 case GGML_TYPE_Q4_0:
13337 case GGML_TYPE_Q4_1:
13338 case GGML_TYPE_Q5_0:
13339 case GGML_TYPE_Q5_1:
13340 case GGML_TYPE_Q8_0:
13341 case GGML_TYPE_Q2_K:
13342 case GGML_TYPE_Q3_K:
13343 case GGML_TYPE_Q4_K:
13344 case GGML_TYPE_Q5_K:
13345 case GGML_TYPE_Q6_K:
13346 case GGML_TYPE_IQ1_S:
13347 case GGML_TYPE_IQ1_M:
13348 case GGML_TYPE_IQ2_XXS:
13349 case GGML_TYPE_IQ2_XS:
13350 case GGML_TYPE_IQ2_S:
13351 case GGML_TYPE_IQ3_XXS:
13352 case GGML_TYPE_IQ3_S:
13353 case GGML_TYPE_IQ4_XS:
13354 case GGML_TYPE_IQ4_NL:
13355 case GGML_TYPE_MXFP4:
13356 break;
13357 default:
13358 return false;
13359 }
13360 struct ggml_tensor * a;
13361 struct ggml_tensor * b;
13362 if (op->op == GGML_OP_MUL_MAT) {
13363 a = op->src[0];
13364 b = op->src[1];
13365 } else {
13366 a = op->src[2];
13367 b = op->src[1];
13368 }
13369 if (a->ne[3] != b->ne[3]) {
13370 return false;
13371 }
13372                if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
13373                    !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
13374 return false;
13375 }
13376 if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
13377 // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
13378 // So don't support this combination for now.
13379 return false;
13380 }
13381
13382 return true;
13383 }
13384 case GGML_OP_FLASH_ATTN_EXT:
13385 {
13386 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13387                auto device = ggml_vk_get_device(ctx->device);
13388 bool coopmat2 = device->coopmat2;
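                // HSK/HSV are the K and V head sizes (ne[0] of the K and V tensors)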
13389 uint32_t HSK = op->src[1]->ne[0];
13390 uint32_t HSV = op->src[2]->ne[0];
13391 if ((HSK % 8) != 0 || (HSV % 8) != 0) {
13392 return false;
13393 }
13394 if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
13395 return false;
13396 }
13397 if (op->src[0]->type != GGML_TYPE_F32) {
13398 return false;
13399 }
13400 if (op->type != GGML_TYPE_F32) {
13401 return false;
13402 }
13403 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
13404 return false;
13405 }
13406 // It's straightforward to support different K/V dequant, but would
13407 // significantly increase the number of pipelines
13408 if (op->src[1]->type != op->src[2]->type) {
13409 return false;
13410 }
13411 switch (op->src[1]->type) {
13412 case GGML_TYPE_F16:
13413 case GGML_TYPE_F32:
13414 case GGML_TYPE_Q4_0:
13415 case GGML_TYPE_Q8_0:
13416 // supported in scalar and coopmat2 paths
13417 break;
13418 case GGML_TYPE_Q4_1:
13419 case GGML_TYPE_Q5_0:
13420 case GGML_TYPE_Q5_1:
13421 // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
13422 //case GGML_TYPE_Q2_K:
13423 //case GGML_TYPE_Q3_K:
13424 //case GGML_TYPE_Q4_K:
13425 //case GGML_TYPE_Q5_K:
13426 //case GGML_TYPE_Q6_K:
13427 //case GGML_TYPE_IQ1_S:
13428 //case GGML_TYPE_IQ1_M:
13429 //case GGML_TYPE_IQ2_XXS:
13430 //case GGML_TYPE_IQ2_XS:
13431 //case GGML_TYPE_IQ2_S:
13432 //case GGML_TYPE_IQ3_XXS:
13433 //case GGML_TYPE_IQ3_S:
13434 //case GGML_TYPE_IQ4_XS:
13435 case GGML_TYPE_IQ4_NL:
13436 // currently supported only in coopmat2 path
13437 if (!coopmat2) {
13438 return false;
13439 }
13440 break;
13441 default:
13442 return false;
13443 }
13444 if (!coopmat2 && !device->subgroup_shuffle) {
13445 // scalar FA uses subgroupShuffle
13446 return false;
13447 }
13448 return true;
13449 }
13450 case GGML_OP_GET_ROWS:
13451 {
13452 switch (op->src[0]->type) {
13453 case GGML_TYPE_F32:
13454 case GGML_TYPE_F16:
13455 case GGML_TYPE_BF16:
13456 case GGML_TYPE_Q4_0:
13457 case GGML_TYPE_Q4_1:
13458 case GGML_TYPE_Q5_0:
13459 case GGML_TYPE_Q5_1:
13460 case GGML_TYPE_Q8_0:
13461 case GGML_TYPE_Q2_K:
13462 case GGML_TYPE_Q3_K:
13463 case GGML_TYPE_Q4_K:
13464 case GGML_TYPE_Q5_K:
13465 case GGML_TYPE_Q6_K:
13466 case GGML_TYPE_IQ1_S:
13467 case GGML_TYPE_IQ1_M:
13468 case GGML_TYPE_IQ2_XXS:
13469 case GGML_TYPE_IQ2_XS:
13470 case GGML_TYPE_IQ2_S:
13471 case GGML_TYPE_IQ3_XXS:
13472 case GGML_TYPE_IQ3_S:
13473 case GGML_TYPE_IQ4_XS:
13474 case GGML_TYPE_IQ4_NL:
13475 case GGML_TYPE_MXFP4:
13476 return true;
13477 default:
13478 return false;
13479 }
13480 }
13481 case GGML_OP_SET_ROWS:
13482 {
13483 switch (op->type) {
13484 case GGML_TYPE_F32:
13485 case GGML_TYPE_F16:
13486 case GGML_TYPE_BF16:
13487 case GGML_TYPE_Q4_0:
13488 case GGML_TYPE_Q4_1:
13489 case GGML_TYPE_Q5_0:
13490 case GGML_TYPE_Q5_1:
13491 case GGML_TYPE_Q8_0:
13492 case GGML_TYPE_IQ4_NL:
13493 return true;
13494 default:
13495 return false;
13496 }
13497 }
13498 case GGML_OP_CONT:
13499 case GGML_OP_CPY:
13500 case GGML_OP_DUP:
13501 {
13502 ggml_type src0_type = op->src[0]->type;
13503 ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
13504
13505 if (src0_type == GGML_TYPE_F32) {
13506 switch (src1_type) {
13507 case GGML_TYPE_F32:
13508 case GGML_TYPE_F16:
13509 case GGML_TYPE_BF16:
13510 case GGML_TYPE_Q4_0:
13511 case GGML_TYPE_Q4_1:
13512 case GGML_TYPE_Q5_0:
13513 case GGML_TYPE_Q5_1:
13514 case GGML_TYPE_Q8_0:
13515 case GGML_TYPE_IQ4_NL:
13516 return true;
13517 default:
13518 break;
13519 }
13520 }
13521 if (src1_type == GGML_TYPE_F32) {
13522 switch (src0_type) {
13523 case GGML_TYPE_F16:
13524 case GGML_TYPE_Q4_0:
13525 case GGML_TYPE_Q4_1:
13526 case GGML_TYPE_Q5_0:
13527 case GGML_TYPE_Q5_1:
13528 case GGML_TYPE_Q8_0:
13529 case GGML_TYPE_IQ4_NL:
13530 return true;
13531 default:
13532 break;
13533 }
13534 }
13535
13536 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
13537 return true;
13538 }
13539
13540 if (
13541 (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) ||
13542 (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32)
13543 ) {
13544 return true;
13545 }
13546
13547 // We can handle copying from a type to the same type if it's
13548 // contiguous (memcpy). We use f16 or f32 shaders to do the copy,
13549 // so the type/block size must be a multiple of 4.
13550 if (src0_type == src1_type &&
13551                ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
13552                (ggml_type_size(src0_type) % 2) == 0) {
13553 return true;
13554 }
13555 return false;
13556 }
13557 case GGML_OP_REPEAT:
13558            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
13559 case GGML_OP_REPEAT_BACK:
13560 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
13561 case GGML_OP_ROPE:
13562 case GGML_OP_ROPE_BACK:
13563 case GGML_OP_NONE:
13564 case GGML_OP_RESHAPE:
13565 case GGML_OP_VIEW:
13566 case GGML_OP_PERMUTE:
13567 case GGML_OP_TRANSPOSE:
13568 case GGML_OP_RMS_NORM:
13569 return true;
13570 case GGML_OP_NORM:
13571 case GGML_OP_GROUP_NORM:
13572 case GGML_OP_L2_NORM:
13573            return ggml_is_contiguous(op->src[0]);
13574 case GGML_OP_ADD:
13575 case GGML_OP_SUB:
13576 case GGML_OP_MUL:
13577 case GGML_OP_DIV:
13578 return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
13579 (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
13580 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
13581 case GGML_OP_ADD_ID:
13582 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
13583 op->type == GGML_TYPE_F32;
13584 case GGML_OP_SILU_BACK:
13585 case GGML_OP_RMS_NORM_BACK:
13586 case GGML_OP_SQR:
13587 case GGML_OP_SQRT:
13588 case GGML_OP_SIN:
13589 case GGML_OP_COS:
13590 case GGML_OP_CLAMP:
13591 case GGML_OP_LEAKY_RELU:
13592 case GGML_OP_OPT_STEP_ADAMW:
13593 case GGML_OP_OPT_STEP_SGD:
13594 return op->src[0]->type == GGML_TYPE_F32;
13595 case GGML_OP_ARGSORT:
13596 return op->ne[0] <= max_argsort_cols;
13597 case GGML_OP_UPSCALE:
13598 case GGML_OP_ACC:
13599 case GGML_OP_CONCAT:
13600 case GGML_OP_SCALE:
13601 case GGML_OP_PAD:
13602 case GGML_OP_ROLL:
13603 case GGML_OP_DIAG_MASK_INF:
13604 case GGML_OP_SOFT_MAX:
13605 case GGML_OP_SOFT_MAX_BACK:
13606 return true;
13607 case GGML_OP_SUM:
13608 case GGML_OP_SUM_ROWS:
13609 case GGML_OP_MEAN:
            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
13611 case GGML_OP_ARGMAX:
13612 case GGML_OP_COUNT_EQUAL:
13613 case GGML_OP_IM2COL:
13614 case GGML_OP_IM2COL_3D:
13615 case GGML_OP_TIMESTEP_EMBEDDING:
13616 case GGML_OP_CONV_2D_DW:
13617 case GGML_OP_POOL_2D:
13618 case GGML_OP_RWKV_WKV6:
13619 case GGML_OP_RWKV_WKV7:
13620 return true;
13621 case GGML_OP_SSM_SCAN:
13622 {
13623 for (int i = 0; i < 6; i++) {
                    if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
13625 return false;
13626 }
13627 }
13628 if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
13629 return false;
13630 }
13631 if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
13632 return false;
13633 }
13634
13635 const uint32_t d_state = op->src[0]->ne[0];
13636 const uint32_t head_dim = op->src[0]->ne[1];
13637
13638 bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
13639 if (!is_mamba2) {
13640 return false;
13641 }
13642
13643 if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
13644 return false;
13645 }
13646
13647 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                const vk_device& device = ggml_vk_get_device(ctx->device);
13649
13650 const uint32_t SPLIT_H = 16;
13651
13652 size_t stateC_size = SPLIT_H * d_state * sizeof(float);
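                // Worked example: with SPLIT_H = 16 and d_state = 128 the per-workgroup
                // state tile is 16 * 128 * 4 bytes = 8 KiB (16 KiB for d_state = 256),
                // which has to fit in maxComputeSharedMemorySize.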
13653
13654 if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
13655 return false;
13656 }
13657
13658 return true;
13659 }
13660 case GGML_OP_SSM_CONV:
13661 return true;
13662 case GGML_OP_CONV_TRANSPOSE_1D:
13663 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
13664 case GGML_OP_CONV_2D:
13665 case GGML_OP_CONV_TRANSPOSE_2D:
13666 {
13667 // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
13668 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                const vk_device& device = ggml_vk_get_device(ctx->device);
13670 if (op->op == GGML_OP_CONV_TRANSPOSE_2D &&
13671 device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) {
13672 return false;
13673 }
13674 // Channel-contiguous format is not supported yet.
13675 return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
13676 op->src[1]->type == GGML_TYPE_F32 &&
13677 op->type == GGML_TYPE_F32 &&
                    ggml_is_contiguous(op->src[0]) &&
                    ggml_is_contiguous(op->src[1]) &&
                    ggml_is_contiguous(op));
13681 }
13682 default:
13683 return false;
13684 }
13685
13686 UNUSED(dev);
13687}
13688
13689static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
13690 if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
13691 return false;
13692 }
13693
13694 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13695 ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
13696
13697 return buft_ctx->device->idx == ctx->device;
13698}
13699
13700static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
13701 const int min_batch_size = 32;
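    // Heuristic: only report an op as worth offloading when the batch is large enough
    // to amortize host<->device transfers: ne[1] >= 32 for regular ops (excluding
    // GET_ROWS), or ne[2] >= 32 for MUL_MAT_ID.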
13702
13703 return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
13704 (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
13705
13706 UNUSED(dev);
13707}
13708
13709static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
13710 /* .get_name = */ ggml_backend_vk_device_get_name,
13711 /* .get_description = */ ggml_backend_vk_device_get_description,
13712 /* .get_memory = */ ggml_backend_vk_device_get_memory,
13713 /* .get_type = */ ggml_backend_vk_device_get_type,
13714 /* .get_props = */ ggml_backend_vk_device_get_props,
13715 /* .init_backend = */ ggml_backend_vk_device_init,
13716 /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
13717 /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
13718 /* .buffer_from_host_ptr = */ NULL,
13719 /* .supports_op = */ ggml_backend_vk_device_supports_op,
13720 /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
13721 /* .offload_op = */ ggml_backend_vk_device_offload_op,
13722 /* .event_new = */ NULL,
13723 /* .event_free = */ NULL,
13724 /* .event_synchronize = */ NULL,
13725};
13726
13727static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
13728 UNUSED(reg);
13729 return GGML_VK_NAME;
13730}
13731
13732static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
13733 UNUSED(reg);
13734 return ggml_backend_vk_get_device_count();
13735}
13736
13737static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
13738 static std::vector<ggml_backend_dev_t> devices;
13739
13740 static bool initialized = false;
13741
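    // Build the device list once, lazily; a function-local mutex makes concurrent
    // first calls safe.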
13742 {
13743 static std::mutex mutex;
13744 std::lock_guard<std::mutex> lock(mutex);
13745 if (!initialized) {
13746 for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
13747 ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
13748 char desc[256];
                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
                ctx->device = i;
                ctx->name = GGML_VK_NAME + std::to_string(i);
                ctx->description = desc;
                ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
                ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
                devices.push_back(new ggml_backend_device {
13756 /* .iface = */ ggml_backend_vk_device_i,
13757 /* .reg = */ reg,
13758 /* .context = */ ctx,
13759 });
13760 }
13761 initialized = true;
13762 }
13763 }
13764
13765 GGML_ASSERT(device < devices.size());
13766 return devices[device];
13767}
13768
13769static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
13770 /* .get_name = */ ggml_backend_vk_reg_get_name,
13771 /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
13772 /* .get_device = */ ggml_backend_vk_reg_get_device,
13773 /* .get_proc_address = */ NULL,
13774};
13775
13776ggml_backend_reg_t ggml_backend_vk_reg() {
13777 static ggml_backend_reg reg = {
13778 /* .api_version = */ GGML_BACKEND_API_VERSION,
13779 /* .iface = */ ggml_backend_vk_reg_i,
13780 /* .context = */ nullptr,
13781 };
13782 try {
13783 ggml_vk_instance_init();
13784 return &reg;
13785 } catch (const vk::SystemError& e) {
13786 VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
13787 return nullptr;
13788 } catch (const std::exception &e) {
13789 VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what());
13790 return nullptr;
13791 } catch (...) {
13792 VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init");
13793 return nullptr;
13794 }
13795}
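
// Illustrative usage sketch (assumes the generic ggml-backend registry/device API from
// ggml-backend.h; not specific to this backend):
//
//     ggml_backend_reg_t reg = ggml_backend_vk_reg();
//     if (reg != nullptr && ggml_backend_reg_dev_count(reg) > 0) {
//         ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);
//         ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
//         // ... build a ggml graph and compute it with this backend ...
//         ggml_backend_free(backend);
//     }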
13796
13797// Extension availability
13798static bool ggml_vk_instance_validation_ext_available() {
13799#ifdef GGML_VULKAN_VALIDATE
13800 // Check if validation layer provides the extension
13801 const std::string layer_name = "VK_LAYER_KHRONOS_validation";
13802 for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
13803 if (layer_name == layer.layerName.data()) {
13804 for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
13805 if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) {
13806 return true;
13807 }
13808 }
13809 }
13810 }
13811
13812 std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl;
13813#endif
13814 return false;
13815}
13816static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
13817#ifdef __APPLE__
13818 // Check for portability enumeration extension for MoltenVK support
13819 for (const auto& properties : instance_extensions) {
13820 if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
13821 return true;
13822 }
13823 }
13824 std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
13825#endif
13826 return false;
13827
13828 UNUSED(instance_extensions);
13829}
13830
13831// Extension availability
13832static bool ggml_vk_instance_debug_utils_ext_available(
13833 const std::vector<vk::ExtensionProperties> & instance_extensions) {
    // Check for the VK_EXT_debug_utils instance extension (used to attach debug names/labels to Vulkan objects)
13835 for (const auto & properties : instance_extensions) {
        if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
13837 return true;
13838 }
13839 }
13840
13841 std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
13842 return false;
13843
13844 UNUSED(instance_extensions);
13845}
13846
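// Minimum device requirement for this backend: 16-bit storage buffer access
// (storageBuffer16BitAccess), queried via the Vulkan 1.1 feature struct chained
// through pNext.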
13847static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
13848 VkPhysicalDeviceFeatures2 device_features2;
13849 device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
13850
13851 VkPhysicalDeviceVulkan11Features vk11_features;
13852 vk11_features.pNext = nullptr;
13853 vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
13854 device_features2.pNext = &vk11_features;
13855
    vkGetPhysicalDeviceFeatures2(vkdev, &device_features2);
13857
13858 return vk11_features.storageBuffer16BitAccess;
13859}
13860
13861static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
13862 switch (props.vendorID) {
13863 case VK_VENDOR_ID_INTEL:
        // Only allow Xe2 GPUs for now: Xe2 sees a significant performance boost from
        // cooperative matrix, while some older hardware (e.g. Arc A770) regresses.
13866 return arch == vk_device_architecture::INTEL_XE2;
13867 case VK_VENDOR_ID_AMD:
13868 if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
13869 // Workaround for AMD proprietary driver reporting support on all GPUs
13870 return arch == vk_device_architecture::AMD_RDNA3;
13871 }
13872 return true;
13873 default:
13874 return true;
13875 }
13876}
13877
13878// checks
13879
13880#ifdef GGML_VULKAN_CHECK_RESULTS
13881static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
13882 if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {
13883 return;
13884 }
13885 for (int j = 0; j < level; j++) {
13886 std::cerr << " ";
13887 }
13888 std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl;
13889
13890 done.push_back(tensor);
13891
13892 for (int i = 0; i < GGML_MAX_SRC; i++) {
13893 if (tensor->src[i] != nullptr) {
13894 ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
13895 }
13896 }
13897}
13898
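// Print a 10x10 window of values around (i0, i1) in the (i2, i3) plane for
// f32/f16/i32 tensors; positions outside the tensor are left blank.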
13899static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
13900 if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
13901 return;
13902 }
13903 i0 = std::max(i0, 5);
13904 i1 = std::max(i1, 5);
13905 i2 = std::max(i2, 0);
13906 i3 = std::max(i3, 0);
13907 fprintf(stderr, " ");
13908 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
13909 fprintf(stderr, "%7d ", idx1);
13910 }
13911 fprintf(stderr, "\n");
13912 for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
13913 fprintf(stderr, "%7d: ", idx0);
13914 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
13915 if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
13916 float val;
13917 if (tensor->type == GGML_TYPE_F32) {
13918 val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
13919 } else if (tensor->type == GGML_TYPE_F16) {
13920 val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
13921 } else if (tensor->type == GGML_TYPE_I32) {
13922 val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
13923 } else {
13924 GGML_ABORT("fatal error");
13925 }
13926 fprintf(stderr, "% 7.2f ", val);
13927 } else {
13928 fprintf(stderr, " ");
13929 }
13930 }
13931 fprintf(stderr, "\n");
13932 }
13933}
13934
13935static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
13936 void * tensor_data = tensor->data;
13937
13938 const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
13939
13940 if (is_gpu) {
13941 const size_t tensor_size = ggml_nbytes(tensor);
13942 tensor_data = malloc(tensor_size);
13943
13944 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
13945
13946 vk_buffer buffer_gpu = buf_ctx->dev_buffer;
13947 ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);
13948 }
13949
13950 std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
13951 std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
13952 if (tensor->src[0] != nullptr) {
13953 std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl;
13954 }
13955 if (tensor->src[1] != nullptr) {
13956 std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl;
13957 }
13958 std::cerr << std::endl << "Result:" << std::endl;
13959 ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
13960 std::cerr << std::endl;
13961 std::vector<const ggml_tensor *> done;
13962 ggml_vk_print_graph_origin(tensor, done);
13963
13964 if (is_gpu) {
13965 free(tensor_data);
13966 }
13967}
13968
13969void * comp_result;
13970size_t comp_size;
13971size_t comp_nb[GGML_MAX_DIMS];
13972size_t check_counter = 0;
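// Reference pass for GGML_VULKAN_CHECK_RESULTS: clone the node's sources into host
// memory, rebuild the same (possibly fused) ops with the regular ggml CPU path,
// compute them, and stash the result in comp_result/comp_size/comp_nb so that
// ggml_vk_check_results_1() can compare it against the Vulkan output.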
13973static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
13974 ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
13975 if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
13976 return;
13977 }
13978
13979 check_counter++;
13980 if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
13981 return;
13982 }
13983
13984 VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
13985
13986 struct ggml_init_params iparams = {
13987 /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
13988 /*.mem_buffer =*/ NULL,
13989 /*.no_alloc =*/ false,
13990 };
13991
13992 struct ggml_context * ggml_ctx = ggml_init(iparams);
13993
13994 std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
13995 const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
13996
13997 std::map<ggml_tensor *, ggml_tensor *> cloned_tensors;
13998 std::vector<void *> cloned_mallocs;
13999
14000 struct ggml_tensor * tensor_clone = nullptr;
14001
14002 for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) {
14003 tensor = cgraph->nodes[tensor_idx + f];
14004 for (int i = 0; i < GGML_MAX_SRC; i++) {
14005 ggml_tensor * srci = tensor->src[i];
14006 if (srci == nullptr) {
14007 continue;
14008 }
14009 // If a src tensor has been cloned, use that one
14010 auto it = cloned_tensors.find(srci);
14011 if (it != cloned_tensors.end()) {
14012 src_clone[i] = it->second;
14013 continue;
14014 }
14015 ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
14016 size_t srci_size = ggml_nbytes(srci);
14017
14018 src_clone[i] = srci_clone;
14019 void *src_buffer = malloc(srci_size);
14020 cloned_mallocs.push_back(src_buffer);
14021
14022 srci_clone->data = src_buffer;
14023 if (ggml_backend_buffer_is_host(srci->buffer)) {
14024 memcpy(srci_clone->data, srci->data, srci_size);
14025 memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
14026 } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
14027 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
14028 vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
14029 uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
14030 if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
14031 for (int i3 = 0; i3 < srci->ne[3]; i3++) {
14032 for (int i2 = 0; i2 < srci->ne[2]; i2++) {
14033 const int idx = i3*srci->ne[2] + i2;
14034 ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
14035 }
14036 }
14037
14038 srci_clone->nb[0] = srci->nb[0];
14039 srci_clone->nb[1] = srci->nb[1];
14040 for (int i = 2; i < GGML_MAX_DIMS; i++) {
14041 srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
14042 }
14043 } else {
14044 if (offset + srci_size >= buffer_gpu->size) {
14045 srci_size = buffer_gpu->size - offset;
14046 }
14047 ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
14048 memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
14049 }
14050 } else {
14051 GGML_ABORT("fatal error");
14052 }
14053
14054 if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
14055 ggml_vk_print_tensor(srci, srci_name[i]);
14056 }
14057 }
14058
14059 if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
14060 const float * params = (const float *)tensor->op_params;
14061 tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
14062 if (src_clone[4]) {
14063 ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
14064 }
14065 } else if (tensor->op == GGML_OP_MUL_MAT) {
14066 tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
14067 } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
14068 tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
14069 } else if (tensor->op == GGML_OP_SUB) {
14070 tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
14071 } else if (tensor->op == GGML_OP_MUL) {
14072 tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
14073 } else if (tensor->op == GGML_OP_DIV) {
14074 tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
14075 } else if (tensor->op == GGML_OP_CONCAT) {
14076 tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
14077 } else if (tensor->op == GGML_OP_UPSCALE) {
14078 tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
14079 } else if (tensor->op == GGML_OP_SCALE) {
14080 const float * params = (const float *)tensor->op_params;
14081 tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
14082 } else if (tensor->op == GGML_OP_SQR) {
14083 tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
14084 } else if (tensor->op == GGML_OP_SQRT) {
14085 tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
14086 } else if (tensor->op == GGML_OP_SIN) {
14087 tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
14088 } else if (tensor->op == GGML_OP_COS) {
14089 tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
14090 } else if (tensor->op == GGML_OP_CLAMP) {
14091 const float * params = (const float *)tensor->op_params;
14092 tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
14093 } else if (tensor->op == GGML_OP_PAD) {
14094 tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
14095 tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
14096 } else if (tensor->op == GGML_OP_REPEAT) {
14097 tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
14098 } else if (tensor->op == GGML_OP_REPEAT_BACK) {
14099 tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
14100 } else if (tensor->op == GGML_OP_ADD) {
14101 tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
14102 } else if (tensor->op == GGML_OP_ACC) {
14103 tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
14104 } else if (tensor->op == GGML_OP_NORM) {
14105 tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
14106 } else if (tensor->op == GGML_OP_GROUP_NORM) {
14107 const float * float_params = (const float *)tensor->op_params;
14108 tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
14109 } else if (tensor->op == GGML_OP_RMS_NORM) {
14110 tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
14111 } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
14112 const float eps = ((float *) tensor->op_params)[0];
14113 tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
14114 } else if (tensor->op == GGML_OP_SILU_BACK) {
14115 tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
14116 } else if (tensor->op == GGML_OP_L2_NORM) {
14117 const float eps = ((float *) tensor->op_params)[0];
14118 tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
14119 } else if (tensor->op == GGML_OP_SOFT_MAX) {
14120 if (tensor->src[1] != nullptr) {
14121 const float * params = (const float *)tensor->op_params;
14122 tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
14123 } else {
14124 tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
14125 }
14126 } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
14127 tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
14128 } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
14129 tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
14130 } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
14131 const int n_dims = ((int32_t *) tensor->op_params)[1];
14132 const int mode = ((int32_t *) tensor->op_params)[2];
14133 //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
14134 const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
14135 const float freq_base = ((float *) tensor->op_params)[5];
14136 const float freq_scale = ((float *) tensor->op_params)[6];
14137 const float ext_factor = ((float *) tensor->op_params)[7];
14138 const float attn_factor = ((float *) tensor->op_params)[8];
14139 const float beta_fast = ((float *) tensor->op_params)[9];
14140 const float beta_slow = ((float *) tensor->op_params)[10];
14141 if (mode & GGML_ROPE_TYPE_MROPE) {
14142 int32_t *sections = ((int32_t *) tensor->op_params) + 11;
14143 if (tensor->op == GGML_OP_ROPE) {
14144 tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
14145 } else {
14146 tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
14147 }
14148 } else {
14149 if (tensor->op == GGML_OP_ROPE) {
14150 tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
14151 } else {
14152 tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
14153 }
14154 }
14155 } else if (tensor->op == GGML_OP_UNARY) {
14156 switch (ggml_get_unary_op(tensor)) {
14157 case GGML_UNARY_OP_EXP:
14158 tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
14159 break;
14160 case GGML_UNARY_OP_SILU:
14161 tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
14162 break;
14163 case GGML_UNARY_OP_GELU:
14164 tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
14165 break;
14166 case GGML_UNARY_OP_GELU_ERF:
14167 tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
14168 break;
14169 case GGML_UNARY_OP_GELU_QUICK:
14170 tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
14171 break;
14172 case GGML_UNARY_OP_RELU:
14173 tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
14174 break;
14175 case GGML_UNARY_OP_TANH:
14176 tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
14177 break;
14178 case GGML_UNARY_OP_SIGMOID:
14179 tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
14180 break;
14181 case GGML_UNARY_OP_HARDSIGMOID:
14182 tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
14183 break;
14184 case GGML_UNARY_OP_HARDSWISH:
14185 tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
14186 break;
14187 default:
14188 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
14189 GGML_ABORT("fatal error");
14190 }
14191 } else if (tensor->op == GGML_OP_GLU) {
14192 if (src_clone[1] == nullptr) {
14193 tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
14194 } else {
14195 tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
14196 }
14197 ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
14198 ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
14199 } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
14200 if (tensor->src[1] == nullptr) {
14201 tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
14202 tensor_clone->type = tensor->type;
14203 } else {
14204 tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
14205 }
14206 } else if (tensor->op == GGML_OP_CONT) {
14207 tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
14208 } else if (tensor->op == GGML_OP_RESHAPE) {
14209 tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
14210 } else if (tensor->op == GGML_OP_VIEW) {
14211 tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
14212 } else if (tensor->op == GGML_OP_PERMUTE) {
14213 int32_t * params = (int32_t *)tensor->op_params;
14214 tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
14215 } else if (tensor->op == GGML_OP_TRANSPOSE) {
14216 tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
14217 } else if (tensor->op == GGML_OP_GET_ROWS) {
14218 tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
14219 } else if (tensor->op == GGML_OP_ARGSORT) {
14220 tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
14221 } else if (tensor->op == GGML_OP_SUM) {
14222 tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
14223 } else if (tensor->op == GGML_OP_SUM_ROWS) {
14224 tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
14225 } else if (tensor->op == GGML_OP_MEAN) {
14226 tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
14227 } else if (tensor->op == GGML_OP_ARGMAX) {
14228 tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
14229 } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
14230 tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
14231 } else if (tensor->op == GGML_OP_IM2COL) {
14232 const int32_t s0 = tensor->op_params[0];
14233 const int32_t s1 = tensor->op_params[1];
14234 const int32_t p0 = tensor->op_params[2];
14235 const int32_t p1 = tensor->op_params[3];
14236 const int32_t d0 = tensor->op_params[4];
14237 const int32_t d1 = tensor->op_params[5];
14238
14239 const bool is_2D = tensor->op_params[6] == 1;
14240 tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
14241 } else if (tensor->op == GGML_OP_IM2COL_3D) {
14242 const int32_t s0 = tensor->op_params[0];
14243 const int32_t s1 = tensor->op_params[1];
14244 const int32_t s2 = tensor->op_params[2];
14245 const int32_t p0 = tensor->op_params[3];
14246 const int32_t p1 = tensor->op_params[4];
14247 const int32_t p2 = tensor->op_params[5];
14248 const int32_t d0 = tensor->op_params[6];
14249 const int32_t d1 = tensor->op_params[7];
14250 const int32_t d2 = tensor->op_params[8];
14251 const int32_t IC = tensor->op_params[9];
14252
14253 tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
14254 } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
14255 const int32_t dim = tensor->op_params[0];
14256 const int32_t max_period = tensor->op_params[1];
14257 tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
14258 } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
14259 const int32_t s0 = tensor->op_params[0];
14260 const int32_t p0 = tensor->op_params[1];
14261 const int32_t d0 = tensor->op_params[2];
14262 tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
14263 } else if (tensor->op == GGML_OP_POOL_2D) {
14264 enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
14265 const int32_t k0 = tensor->op_params[1];
14266 const int32_t k1 = tensor->op_params[2];
14267 const int32_t s0 = tensor->op_params[3];
14268 const int32_t s1 = tensor->op_params[4];
14269 const int32_t p0 = tensor->op_params[5];
14270 const int32_t p1 = tensor->op_params[6];
14271
14272 tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
14273 } else if (tensor->op == GGML_OP_CONV_2D) {
14274 const int32_t s0 = tensor->op_params[0];
14275 const int32_t s1 = tensor->op_params[1];
14276 const int32_t p0 = tensor->op_params[2];
14277 const int32_t p1 = tensor->op_params[3];
14278 const int32_t d0 = tensor->op_params[4];
14279 const int32_t d1 = tensor->op_params[5];
14280 tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
14281 } else if (tensor->op == GGML_OP_CONV_2D_DW) {
14282 const int32_t s0 = tensor->op_params[0];
14283 const int32_t s1 = tensor->op_params[1];
14284 const int32_t p0 = tensor->op_params[2];
14285 const int32_t p1 = tensor->op_params[3];
14286 const int32_t d0 = tensor->op_params[4];
14287 const int32_t d1 = tensor->op_params[5];
14288 tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
14289 } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
14290 const int32_t s = tensor->op_params[0];
14291 tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
14292 } else if (tensor->op == GGML_OP_LEAKY_RELU) {
14293 const float * op_params = (const float *)tensor->op_params;
14294 tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
14295 } else if (tensor->op == GGML_OP_RWKV_WKV6) {
14296 tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
14297 src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
14298 } else if (tensor->op == GGML_OP_RWKV_WKV7) {
14299 tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
14300 src_clone[4], src_clone[5], src_clone[6]);
14301 } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
14302 src_clone[0]->flags = tensor->src[0]->flags;
14303 tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
14304 src_clone[2], src_clone[3], src_clone[4]);
14305 } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
14306 src_clone[0]->flags = tensor->src[0]->flags;
14307 tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
14308 src_clone[2]);
14309 } else if (tensor->op == GGML_OP_ADD_ID) {
14310 tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
14311 } else if (tensor->op == GGML_OP_SSM_SCAN) {
14312 tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
14313 src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
14314 } else if (tensor->op == GGML_OP_SSM_CONV) {
14315 tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
14316 } else if (tensor->op == GGML_OP_ROLL) {
14317 const int32_t s0 = tensor->op_params[0];
14318 const int32_t s1 = tensor->op_params[1];
14319 const int32_t s2 = tensor->op_params[2];
14320 const int32_t s3 = tensor->op_params[3];
14321 tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3);
        } else {
14324 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
14325 GGML_ABORT("fatal error");
14326 }
14327 cloned_tensors[tensor] = tensor_clone;
14328 }
14329
14330 ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
14331 ggml_build_forward_expand(cgraph_cpu, tensor_clone);
14332
14333 ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
14334
14335 if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
14336 ggml_vk_print_tensor(tensor_clone, "tensor_clone");
14337 }
14338
14339 comp_size = ggml_nbytes(tensor_clone);
14340
14341 comp_result = malloc(comp_size);
14342 memcpy(comp_result, tensor_clone->data, comp_size);
14343 memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
14344
14345 for (auto m : cloned_mallocs) {
14346 free(m);
14347 }
14348
14349 ggml_free(ggml_ctx);
14350
14351 VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
14352}
14353
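// Compare the Vulkan output of the node against the CPU reference computed by
// ggml_vk_check_results_0(): accumulate a normalized average error over all elements,
// dump tensor and graph details when NaN/Inf mismatch or the error is too large, and
// abort on failure.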
14354static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
14355 ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
14356 if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
14357 return;
14358 }
14359
14360 if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
14361 return;
14362 }
14363
14364 VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")");
14365
14366 ggml_tensor * src0 = tensor->src[0];
14367 ggml_tensor * src1 = tensor->src[1];
14368 ggml_tensor * src2 = tensor->src[2];
14369 ggml_tensor * src3 = tensor->src[3];
14370
14371 void * tensor_data = tensor->data;
14372
14373 if (ggml_backend_buffer_is_vk(tensor->buffer)) {
14374 size_t tensor_size = ggml_nbytes(tensor);
14375 tensor_data = malloc(tensor_size);
14376
14377 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
14378
14379 vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
14380 uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;
14381 if (offset + tensor_size >= buffer_gpu->size) {
14382 tensor_size = buffer_gpu->size - offset;
14383 }
14384
14385 ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);
14386 }
14387
14388 float first_error_result = -1.0f;
14389 float first_error_correct = -1.0f;
14390 std::array<int, 4> first_error = { -1, -1, -1, -1 };
14391 double avg_err = 0.0;
14392 size_t counter = 0;
14393
14394 for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
14395 for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
14396 for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
14397 for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
14398 const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;
14399 float correct = 0.0f;
14400 float result = 0.0f;
14401
14402 if (buffer_size_fit) {
14403 if (tensor->type == GGML_TYPE_F32) {
14404 correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
14405 result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
14406 } else if (tensor->type == GGML_TYPE_F16) {
14407 correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
14408 result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
14409 } else if (tensor->type == GGML_TYPE_BF16) {
14410 correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
14411 result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
14412 } else if (tensor->type == GGML_TYPE_I32) {
14413 correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
14414 result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
14415 } else if (tensor->type == GGML_TYPE_I64) {
14416 correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
14417 result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
14418 } else {
14419 std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
14420 }
14421 } else {
14422 std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;
14423 GGML_ABORT("fatal error");
14424 }
14425
14426 if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {
14427 std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
14428 std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
14429 if (src0 != nullptr) {
14430 std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
14431 }
14432 if (src1 != nullptr) {
14433 std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
14434 }
14435 if (src2 != nullptr) {
14436 std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
14437 }
14438 if (src3 != nullptr) {
14439 std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
14440 }
14441 std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
14442 std::cerr << std::endl << "Result:" << std::endl;
14443 ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
14444 std::cerr << std::endl << "Correct:" << std::endl;
14445 ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);
14446 std::cerr << std::endl;
14447 std::vector<const ggml_tensor *> done;
14448 ggml_vk_print_graph_origin(tensor, done);
14449 GGML_ABORT("fatal error");
14450 }
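                    // Normalize by max(|correct|, 1.0) so values near zero do not blow
                    // up the relative error.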
14451 const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f;
14452 if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {
14453 first_error[0] = i0;
14454 first_error[1] = i1;
14455 first_error[2] = i2;
14456 first_error[3] = i3;
14457 first_error_result = result;
14458 first_error_correct = correct;
14459 }
14460
                    // Skip infinite values so avg_err does not become NaN; NaN can also
                    // appear in the results, and if both values are NaN the error counts as 0.
14463 if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
14464 avg_err += std::fabs(correct - result) / denom;
14465 }
14466 counter++;
14467 }
14468 }
14469 }
14470 }
14471
14472 avg_err /= counter;
14473
14474 if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
14475 std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
14476 std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
14477 if (src0 != nullptr) {
14478 std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
14479 }
14480 if (src1 != nullptr) {
14481 std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
14482 }
14483 if (src2 != nullptr) {
14484 std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
14485 }
14486 if (src3 != nullptr) {
14487 std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
14488 }
14489 std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
14490 std::cerr << std::endl << "Result:" << std::endl;
14491 ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
14492 std::cerr << std::endl << "Correct:" << std::endl;
14493 ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
14494 std::cerr << std::endl;
14495 std::vector<const ggml_tensor *> done;
14496 ggml_vk_print_graph_origin(tensor, done);
14497 }
14498
14499 if (avg_err > 0.5 || std::isnan(avg_err)) {
14500 std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
14501 std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
14502 if (src0 != nullptr) {
14503 std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
14504 }
14505 if (src1 != nullptr) {
14506 std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
14507 }
14508 if (src2 != nullptr) {
14509 std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
14510 }
14511 if (src3 != nullptr) {
14512 std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
14513 }
14514 std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
14515 std::cerr << std::endl << "Result:" << std::endl;
14516 ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
14517 std::cerr << std::endl << "Correct:" << std::endl;
14518 ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);
14519 std::cerr << std::endl;
14520 std::vector<const ggml_tensor *> done;
14521 ggml_vk_print_graph_origin(tensor, done);
14522 GGML_ABORT("fatal error");
14523 } else {
14524 std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl;
14525 }
14526
14527 free(comp_result);
14528 comp_result = nullptr;
14529 comp_size = 0;
14530
14531 if (ggml_backend_buffer_is_vk(tensor->buffer)) {
14532 free(tensor_data);
14533 }
14534
14535 VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
14536}
14537#endif
14538
14539GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)
14540